diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
index cd44ad570..5884c2a78 100644
--- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
+++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
@@ -14,7 +14,11 @@
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.agents import *  # noqa: F403
-from ..agent_instance import ChatAgent
+from ..agents import (
+    AGENT_INSTANCES_BY_ID,
+    MetaReferenceAgentsImpl,
+    MetaReferenceImplConfig,
+)
 
 
 class MockInferenceAPI:
@@ -42,10 +46,22 @@ class MockInferenceAPI:
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type="progress",
-                delta="Mock response",
+                delta="AI is a fascinating field...",
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type="progress",
+                delta=ToolCallDelta(
+                    content=ToolCall(
+                        call_id="123",
+                        tool_name=BuiltinTool.brave_search.value,
+                        arguments={"query": "AI history"},
+                    ),
+                    parse_status="success",
+                ),
             )
         )
-
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type="complete",
@@ -149,24 +165,32 @@ def mock_memory_api():
 
 
 @pytest.fixture
-def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
-    # You'll need to adjust this based on the actual ChatAgent constructor
+async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
+    impl = MetaReferenceAgentsImpl(
+        config=MetaReferenceImplConfig(),
+        inference_api=mock_inference_api,
+        safety_api=mock_safety_api,
+        memory_api=mock_memory_api,
+    )
+    await impl.initialize()
+
     agent_config = AgentConfig(
         model="test_model",
         instructions="You are a helpful assistant.",
         sampling_params=SamplingParams(),
-        tools=[],
+        tools=[
+            SearchToolDefinition(
+                name="brave_search",
+                api_key="test_key",
+            ),
+        ],
         tool_choice=ToolChoice.auto,
         input_shields=[],
         output_shields=[],
     )
-    return ChatAgent(
-        agent_config=agent_config,
-        inference_api=mock_inference_api,
-        memory_api=mock_memory_api,
-        safety_api=mock_safety_api,
-        builtin_tools=[],
-    )
+    response = await impl.create_agent(agent_config)
+    agent = AGENT_INSTANCES_BY_ID[response.agent_id]
+    return agent
 
 
 @pytest.mark.asyncio
@@ -214,3 +238,68 @@ async def test_run_shields_wrapper(chat_agent):
     assert len(responses) == 2  # StepStart, StepComplete
     assert responses[0].event.payload.step_type.value == "shield_call"
     assert not responses[1].event.payload.step_details.response.is_violation
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
+async def test_chat_agent_complex_turn(chat_agent):
+    # Setup
+    session = chat_agent.create_session("Test Session")
+    request = AgentTurnCreateRequest(
+        agent_id="random",
+        session_id=session.session_id,
+        messages=[UserMessage(content="Tell me about AI and then use a tool.")],
+        stream=True,
+    )
+
+    # Execute the turn
+    responses = []
+    async for response in chat_agent.create_and_execute_turn(request):
+        responses.append(response)
+
+    # Assertions
+    assert len(responses) > 0
+
+    # Check for the presence of different step types
+    step_types = [
+        response.event.payload.step_type
+        for response in responses
+        if hasattr(response.event.payload, "step_type")
+    ]
+
+    assert "shield_call" in step_types, "Shield call step is missing"
+    assert "inference" in step_types, "Inference step is missing"
+    assert "tool_execution" in step_types, "Tool execution step is missing"
+
+    # Check for the presence of start and complete events
+    event_types = [
+        response.event.payload.event_type
+        for response in responses
+        if hasattr(response.event.payload, "event_type")
+    ]
+    assert "start" in event_types, "Start event is missing"
+    assert "complete" in event_types, "Complete event is missing"
+
+    # Check for the presence of tool call
+    tool_calls = [
+        response.event.payload.tool_call
+        for response in responses
+        if hasattr(response.event.payload, "tool_call")
+    ]
+    assert any(
+        tool_call
+        for tool_call in tool_calls
+        if tool_call and tool_call.content.get("name") == "memory"
+    ), "Memory tool call is missing"
+
+    # Check for the final turn complete event
+    assert any(
+        isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
+        for response in responses
+    ), "Turn complete event is missing"
+
+    # Verify the turn was added to the session
+    assert len(session.turns) == 1, "Turn was not added to the session"
+    assert (
+        session.turns[0].input_messages == request.messages
+    ), "Input messages do not match"