diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index f88e4c4c2..2f258e620 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -70,10 +70,10 @@ class TogetherInferenceAdapter(
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -96,6 +96,7 @@ class TogetherInferenceAdapter(
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -130,11 +131,23 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
+        options = get_sampling_options(request)
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
         return {
             "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **get_sampling_options(request),
+            **options,
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/tests/agents/test_agent_persistence.py b/llama_stack/providers/tests/agents/test_agent_persistence.py
new file mode 100644
index 000000000..a15887b33
--- /dev/null
+++ b/llama_stack/providers/tests/agents/test_agent_persistence.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+from llama_stack.providers.datatypes import *  # noqa: F403
+
+from dotenv import load_dotenv
+
+from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+
+# How to run this test:
+#
+# 1. Ensure you have a conda environment with the right dependencies installed.
+#    This includes `pytest` and `pytest-asyncio`.
+#
+# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
+#
+# 3. Run:
+#
+# ```bash
+# PROVIDER_ID= \
+#   PROVIDER_CONFIG=provider_config.yaml \
+#   pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
+#   --tb=short --disable-warnings
+# ```

+load_dotenv()
+
+
+@pytest_asyncio.fixture(scope="session")
+async def agents_settings():
+    impls = await resolve_impls_for_test(
+        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
+    )
+
+    return {
+        "impl": impls[Api.agents],
+        "memory_impl": impls[Api.memory],
+        "common_params": {
+            "model": "Llama3.1-8B-Instruct",
+            "instructions": "You are a helpful assistant.",
+        },
+    }
+
+
+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What's the weather like today?"),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_delete_agents_and_sessions(agents_settings, sample_messages):
+    agents_impl = agents_settings["impl"]
+    # First, create an agent
+    agent_config = AgentConfig(
+        model=agents_settings["common_params"]["model"],
+        instructions=agents_settings["common_params"]["instructions"],
+        enable_session_persistence=True,
+        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        input_shields=[],
+        output_shields=[],
+        tools=[],
+        max_infer_iters=5,
+    )
+
+    create_response = await agents_impl.create_agent(agent_config)
+    agent_id = create_response.agent_id
+
+    # Create a session
+    session_create_response = await agents_impl.create_agent_session(
+        agent_id, "Test Session"
+    )
+    session_id = session_create_response.session_id
+    persistence_store = await kvstore_impl(agents_settings["persistence"])
+
+    await agents_impl.delete_agents_session(agent_id, session_id)
+    session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
+
+    await agents_impl.delete_agents(agent_id)
+    agent_response = await persistence_store.get(f"agent:{agent_id}")
+
+    assert session_response is None
+    assert agent_response is None
+
+
+async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
+    agents_impl = agents_settings["impl"]
+
+    # First, create an agent
+    agent_config = AgentConfig(
+        model=agents_settings["common_params"]["model"],
+        instructions=agents_settings["common_params"]["instructions"],
+        enable_session_persistence=True,
+        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        input_shields=[],
+        output_shields=[],
+        tools=[],
+        max_infer_iters=5,
+    )
+
+    create_response = await agents_impl.create_agent(agent_config)
+    agent_id = create_response.agent_id
+
+    # Create a session
+    session_create_response = await agents_impl.create_agent_session(
+        agent_id, "Test Session"
+    )
+    session_id = session_create_response.session_id
+
+    # Create and execute a turn
+    turn_request = dict(
+        agent_id=agent_id,
+        session_id=session_id,
+        messages=sample_messages,
+        stream=True,
+    )
+
+    turn_response = [
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+    ]
+
+    final_event = turn_response[-1].event.payload
+    turn_id = final_event.turn.turn_id
+    persistence_store = await kvstore_impl(SqliteKVStoreConfig())
+    turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
+    response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
+
+    assert isinstance(response, Turn)
+    assert response == final_event.turn
+    assert turn == final_event.turn
+
+    steps = final_event.turn.steps
+    step_id = steps[0].step_id
+    step_response = await agents_impl.get_agents_step(
+        agent_id, session_id, turn_id, step_id
+    )
+
+    assert isinstance(step_response.step, Step)
+    assert step_response.step == steps[0]
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index e89f672b1..ad49448e2 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -195,6 +195,7 @@ async def test_structured_output(inference_settings):
         "meta-reference",
         "remote::fireworks",
         "remote::tgi",
+        "remote::together",
     ):
         pytest.skip("Other inference providers don't support structured output yet")
 
diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
index 3a1e25d65..2c222ffa1 100644
--- a/llama_stack/providers/tests/inference/test_prompt_adapter.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -7,8 +7,10 @@
 import unittest
 
 from llama_models.llama3.api import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
+from llama_stack.apis.inference.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)
 
 MODEL = "Llama3.1-8B-Instruct"
 
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index add29da99..22ae8a717 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -34,8 +34,6 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
     if params := request.sampling_params:
         for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
             if getattr(params, attr):
-                if attr == "max_tokens":
-                    options["num_predict"] = getattr(params, attr)
                 options[attr] = getattr(params, attr)
 
     if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
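
For reference, a minimal sketch of the request params the new `_get_params` path produces when a chat completion carries a JSON-schema `response_format`. The `AnswerFormat` model, the placeholder model id/prompt strings, and the sampling values are illustrative assumptions, not part of this patch; only the key names and the `json_schema` → `json_object` mapping come from the diff above.

```python
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    # Illustrative output schema; any JSON-schema dict would do here.
    name: str
    year_born: str


# Shape of the Together request params built by _get_params when
# request.response_format.type == ResponseFormatType.json_schema.value:
together_params = {
    "model": "<provider model id from map_to_provider_model>",      # placeholder
    "prompt": "<prompt from chat_completion_request_to_prompt>",    # placeholder
    "stream": False,
    # llama-stack's json_schema response format is mapped onto
    # Together's "json_object" format, with the schema passed through.
    "response_format": {
        "type": "json_object",
        "schema": AnswerFormat.model_json_schema(),
    },
    # plus whatever get_sampling_options(request) returns, e.g.:
    "temperature": 0.7,
    "top_p": 0.95,
}
```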