diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 3301dee96..edcc6adea 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -4,11 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os + import pytest import pytest_asyncio from llama_stack.apis.agents import * # noqa: F403 from llama_stack.providers.tests.resolver import resolve_impls_for_test +from llama_stack.providers.datatypes import * # noqa: F403 + +from dotenv import load_dotenv # How to run this test: # @@ -26,6 +31,8 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test # --tb=short --disable-warnings # ``` +load_dotenv() + @pytest_asyncio.fixture(scope="session") async def agents_settings(): @@ -50,6 +57,13 @@ def sample_messages(): ] +@pytest.fixture +def search_query_messages(): + return [ + UserMessage(content="What are the latest developments in quantum computing?"), + ] + + @pytest.mark.asyncio async def test_create_agent_turn(agents_settings, sample_messages): agents_impl = agents_settings["impl"] @@ -107,3 +121,90 @@ async def test_create_agent_turn(agents_settings, sample_messages): assert final_event.turn.input_messages == sample_messages assert isinstance(final_event.turn.output_message, CompletionMessage) assert len(final_event.turn.output_message.content) > 0 + + +@pytest.mark.asyncio +async def test_create_agent_turn_with_brave_search( + agents_settings, search_query_messages +): + agents_impl = agents_settings["impl"] + + if "BRAVE_SEARCH_API_KEY" not in os.environ: + pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") + + # Create an agent with Brave search tool + agent_config = AgentConfig( + model=agents_settings["common_params"]["model"], + instructions=agents_settings["common_params"]["instructions"], + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[ + SearchToolDefinition( + type=AgentTool.brave_search.value, + api_key=os.environ["BRAVE_SEARCH_API_KEY"], + engine=SearchEngineType.brave, + ) + ], + tool_choice=ToolChoice.auto, + max_infer_iters=5, + ) + + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session with Brave Search" + ) + session_id = session_create_response.session_id + + # Create and execute a turn + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=search_query_messages, + stream=True, + ) + + turn_response = [ + chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + ] + + assert len(turn_response) > 0 + assert all( + isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response + ) + + # Check for expected event types + event_types = [chunk.event.payload.event_type for chunk in turn_response] + assert AgentTurnResponseEventType.turn_start.value in event_types + assert AgentTurnResponseEventType.step_start.value in event_types + assert AgentTurnResponseEventType.step_complete.value in event_types + assert AgentTurnResponseEventType.turn_complete.value in event_types + + # Check for tool execution events + tool_execution_events = [ + chunk + for chunk in turn_response + if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload) + and chunk.event.payload.step_details.step_type == StepType.tool_execution.value + ] + assert len(tool_execution_events) > 0, "No tool execution events found" + + # Check the tool execution details + tool_execution = tool_execution_events[0].event.payload.step_details + assert isinstance(tool_execution, ToolExecutionStep) + assert len(tool_execution.tool_calls) > 0 + assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search + assert len(tool_execution.tool_responses) > 0 + + # Check the final turn complete event + final_event = turn_response[-1].event.payload + assert isinstance(final_event, AgentTurnResponseTurnCompletePayload) + assert isinstance(final_event.turn, Turn) + assert final_event.turn.session_id == session_id + assert final_event.turn.input_messages == search_query_messages + assert isinstance(final_event.turn.output_message, CompletionMessage) + assert len(final_event.turn.output_message.content) > 0