diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index f2602ddde..25de35497 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -54,6 +54,7 @@ class ToolDefinitionCommon(BaseModel): class SearchEngineType(Enum): bing = "bing" brave = "brave" + tavily = "tavily" @json_schema_type diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py index 4c9cdfcd2..a1e7d08f5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py @@ -86,10 +86,13 @@ class PhotogenTool(SingleMessageBuiltinTool): class SearchTool(SingleMessageBuiltinTool): def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None: self.api_key = api_key + self.engine_type = engine if engine == SearchEngineType.bing: self.engine = BingSearch(api_key, **kwargs) elif engine == SearchEngineType.brave: self.engine = BraveSearch(api_key, **kwargs) + elif engine == SearchEngineType.tavily: + self.engine = TavilySearch(api_key, **kwargs) else: raise ValueError(f"Unknown search engine: {engine}") @@ -257,6 +260,21 @@ class BraveSearch: return {"query": query, "top_k": clean_response} +class TavilySearch: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + + async def search(self, query: str) -> str: + response = requests.post( + "https://api.tavily.com/search", + json={"api_key": self.api_key, "query": query}, + ) + return json.dumps(self._clean_tavily_response(response.json())) + + def _clean_tavily_response(self, search_response, top_k=3): + return {"query": search_response["query"], "top_k": search_response["results"]} + + class WolframAlphaTool(SingleMessageBuiltinTool): def __init__(self, api_key: str) -> None: self.api_key = api_key diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 60c047058..ee2f3d29f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -68,6 +68,73 @@ def query_attachment_messages(): ] +async def create_agent_turn_with_search_tool( + agents_stack: Dict[str, object], + search_query_messages: List[object], + common_params: Dict[str, str], + search_tool_definition: SearchToolDefinition, +) -> None: + """ + Create an agent turn with a search tool. + + Args: + agents_stack (Dict[str, object]): The agents stack. + search_query_messages (List[object]): The search query messages. + common_params (Dict[str, str]): The common parameters. + search_tool_definition (SearchToolDefinition): The search tool definition. + """ + + # Create an agent with the search tool + agent_config = AgentConfig( + **{ + **common_params, + "tools": [search_tool_definition], + } + ) + + agent_id, session_id = await create_agent_session( + agents_stack.impls[Api.agents], agent_config + ) + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=search_query_messages, + stream=True, + ) + + turn_response = [ + chunk + async for chunk in await agents_stack.impls[Api.agents].create_agent_turn( + **turn_request + ) + ] + + assert len(turn_response) > 0 + assert all( + isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response + ) + + check_event_types(turn_response) + + # Check for tool execution events + tool_execution_events = [ + chunk + for chunk in turn_response + if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload) + and chunk.event.payload.step_details.step_type == StepType.tool_execution.value + ] + assert len(tool_execution_events) > 0, "No tool execution events found" + + # Check the tool execution details + tool_execution = tool_execution_events[0].event.payload.step_details + assert isinstance(tool_execution, ToolExecutionStep) + assert len(tool_execution.tool_calls) > 0 + assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search + assert len(tool_execution.tool_responses) > 0 + + check_turn_complete_event(turn_response, session_id, search_query_messages) + + class TestAgents: @pytest.mark.asyncio async def test_agent_turns_with_safety( @@ -215,63 +282,34 @@ class TestAgents: async def test_create_agent_turn_with_brave_search( self, agents_stack, search_query_messages, common_params ): - agents_impl = agents_stack.impls[Api.agents] - if "BRAVE_SEARCH_API_KEY" not in os.environ: pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") - # Create an agent with Brave search tool - agent_config = AgentConfig( - **{ - **common_params, - "tools": [ - SearchToolDefinition( - type=AgentTool.brave_search.value, - api_key=os.environ["BRAVE_SEARCH_API_KEY"], - engine=SearchEngineType.brave, - ) - ], - } + search_tool_definition = SearchToolDefinition( + type=AgentTool.brave_search.value, + api_key=os.environ["BRAVE_SEARCH_API_KEY"], + engine=SearchEngineType.brave, + ) + await create_agent_turn_with_search_tool( + agents_stack, search_query_messages, common_params, search_tool_definition ) - agent_id, session_id = await create_agent_session(agents_impl, agent_config) - turn_request = dict( - agent_id=agent_id, - session_id=session_id, - messages=search_query_messages, - stream=True, + @pytest.mark.asyncio + async def test_create_agent_turn_with_tavily_search( + self, agents_stack, search_query_messages, common_params + ): + if "TAVILY_SEARCH_API_KEY" not in os.environ: + pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") + + search_tool_definition = SearchToolDefinition( + type=AgentTool.brave_search.value, # place holder only + api_key=os.environ["TAVILY_SEARCH_API_KEY"], + engine=SearchEngineType.tavily, ) - - turn_response = [ - chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) - ] - - assert len(turn_response) > 0 - assert all( - isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response + await create_agent_turn_with_search_tool( + agents_stack, search_query_messages, common_params, search_tool_definition ) - check_event_types(turn_response) - - # Check for tool execution events - tool_execution_events = [ - chunk - for chunk in turn_response - if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload) - and chunk.event.payload.step_details.step_type - == StepType.tool_execution.value - ] - assert len(tool_execution_events) > 0, "No tool execution events found" - - # Check the tool execution details - tool_execution = tool_execution_events[0].event.payload.step_details - assert isinstance(tool_execution, ToolExecutionStep) - assert len(tool_execution.tool_calls) > 0 - assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search - assert len(tool_execution.tool_responses) > 0 - - check_turn_complete_event(turn_response, session_id, search_query_messages) - def check_event_types(turn_response): event_types = [chunk.event.payload.event_type for chunk in turn_response]