diff --git a/llama_stack/providers/impls/meta_reference/agents/__init__.py b/llama_stack/providers/impls/meta_reference/agents/__init__.py index c0844be3b..156de9a17 100644 --- a/llama_stack/providers/impls/meta_reference/agents/__init__.py +++ b/llama_stack/providers/impls/meta_reference/agents/__init__.py @@ -21,6 +21,7 @@ async def get_provider_impl( deps[Api.inference], deps[Api.memory], deps[Api.safety], + deps[Api.memory_banks], ) await impl.initialize() return impl diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 4dbc71dfa..5a209d0b7 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -137,3 +137,25 @@ class MetaReferenceAgentsImpl(Agents): agent = await self.get_agent(request.agent_id) async for event in agent.create_and_execute_turn(request): yield event + + async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn: + raise NotImplementedError() + + async def get_agents_step( + self, agent_id: str, turn_id: str, step_id: str + ) -> AgentStepResponse: + raise NotImplementedError() + + async def get_agents_session( + self, + agent_id: str, + session_id: str, + turn_ids: Optional[List[str]] = None, + ) -> Session: + raise NotImplementedError() + + async def delete_agents_session(self, agent_id: str, session_id: str) -> None: + raise NotImplementedError() + + async def delete_agents(self, agent_id: str) -> None: + raise NotImplementedError() diff --git a/llama_stack/providers/tests/agents/provider_config_example.yaml b/llama_stack/providers/tests/agents/provider_config_example.yaml new file mode 100644 index 000000000..5b643590c --- /dev/null +++ b/llama_stack/providers/tests/agents/provider_config_example.yaml @@ -0,0 +1,34 @@ +providers: + inference: + - provider_id: together + provider_type: remote::together + config: {} + - provider_id: tgi + provider_type: 
remote::tgi + config: + url: http://127.0.0.1:7001 +# - provider_id: meta-reference +# provider_type: meta-reference +# config: +# model: Llama-Guard-3-1B +# - provider_id: remote +# provider_type: remote +# config: +# host: localhost +# port: 7010 + safety: + - provider_id: together + provider_type: remote::together + config: {} + memory: + - provider_id: faiss + provider_type: meta-reference + config: {} + agents: + - provider_id: meta-reference + provider_type: meta-reference + config: + persistence_store: + namespace: null + type: sqlite + db_path: ~/.llama/runtime/kvstore.db diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 6b8ed5f03..3301dee96 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -7,17 +7,13 @@ import pytest import pytest_asyncio -from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403 - -from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. +# 1. Ensure you have a conda environment with the right dependencies installed. +# This includes `pytest` and `pytest-asyncio`. # # 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. 
# @@ -38,22 +34,76 @@ async def agents_settings(): ) return { - "impl": impls[Api.safety], + "impl": impls[Api.agents], "memory_impl": impls[Api.memory], - "inference_impl": impls[Api.inference], - "safety_impl": impls[Api.safety], + "common_params": { + "model": "Llama3.1-8B-Instruct", + "instructions": "You are a helpful assistant.", + }, } @pytest.fixture -def sample_tool_definition(): - return ToolDefinition( - tool_name="get_weather", - description="Get the current weather", - parameters={ - "location": ToolParamDefinition( - param_type="string", - description="The city and state, e.g. San Francisco, CA", - ), - }, +def sample_messages(): + return [ + UserMessage(content="What's the weather like today?"), + ] + + +@pytest.mark.asyncio +async def test_create_agent_turn(agents_settings, sample_messages): + agents_impl = agents_settings["impl"] + + # First, create an agent + agent_config = AgentConfig( + model=agents_settings["common_params"]["model"], + instructions=agents_settings["common_params"]["instructions"], + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[], + max_infer_iters=5, ) + + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session" + ) + session_id = session_create_response.session_id + + # Create and execute a turn + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=sample_messages, + stream=True, + ) + + turn_response = [ + chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + ] + + assert len(turn_response) > 0 + assert all( + isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response + ) + + # Check for expected event types + event_types = [chunk.event.payload.event_type for chunk in turn_response] + assert 
AgentTurnResponseEventType.turn_start.value in event_types + assert AgentTurnResponseEventType.step_start.value in event_types + assert AgentTurnResponseEventType.step_complete.value in event_types + assert AgentTurnResponseEventType.turn_complete.value in event_types + + # Check the final turn complete event + final_event = turn_response[-1].event.payload + assert isinstance(final_event, AgentTurnResponseTurnCompletePayload) + assert isinstance(final_event.turn, Turn) + assert final_event.turn.session_id == session_id + assert final_event.turn.input_messages == sample_messages + assert isinstance(final_event.turn.output_message, CompletionMessage) + assert len(final_event.turn.output_message.content) > 0 diff --git a/llama_stack/providers/tests/inference/provider_config_example.yaml b/llama_stack/providers/tests/inference/provider_config_example.yaml index 8431b01ac..c4bb4af16 100644 --- a/llama_stack/providers/tests/inference/provider_config_example.yaml +++ b/llama_stack/providers/tests/inference/provider_config_example.yaml @@ -21,5 +21,4 @@ providers: # this is a place to provide such data. provider_data: "test-together": - together_api_key: - 0xdeadbeefputrealapikeyhere + together_api_key: 0xdeadbeefputrealapikeyhere diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 156fde2dd..0afc894cf 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str): scope="session", params=[ {"model": Llama_8B}, - # {"model": Llama_3B}, + {"model": Llama_3B}, ], ids=lambda d: d["model"], )