diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py b/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py new file mode 100644 index 000000000..8061a250c --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import TavilySearchToolConfig +from .tavily_search import TavilySearchToolRuntimeImpl + + +class TavilySearchToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_provider_impl(config: TavilySearchToolConfig, _deps): + impl = TavilySearchToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/config.py b/llama_stack/providers/inline/tool_runtime/tavily_search/config.py new file mode 100644 index 000000000..f7a8f3f09 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/config.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel, Field + + +class TavilySearchToolConfig(BaseModel): + api_key: Optional[str] = Field( + default=None, + description="The Tavily Search API Key", + ) + max_results: int = Field( + default=3, + description="The maximum number of results to return", + ) diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py new file mode 100644 index 000000000..f80d10dfe --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any, Dict, List + +import requests + +from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import TavilySearchToolConfig + + +class TavilySearchToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): + def __init__(self, config: TavilySearchToolConfig): + self.config = config + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + if tool.identifier != "tavily_search": + raise ValueError(f"Tool identifier {tool.identifier} is not supported") + + async def unregister_tool(self, tool_id: str) -> None: + return + + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key + + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: + raise NotImplementedError("Tavily search tool group not supported") + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() + response = requests.post( + "https://api.tavily.com/search", + json={"api_key": api_key, "query": args["query"]}, + ) + print(f"================= Tavily response: {response.json()}") + + return ToolInvocationResult( + content=json.dumps(self._clean_tavily_response(response.json())) + ) + + def _clean_tavily_response(self, search_response, top_k=3): + return {"query": search_response["query"], "top_k": search_response["results"]} diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index d0493810c..9058fb718 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -33,6 +33,14 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig", api_dependencies=[Api.memory, Api.memory_banks, Api.inference], ), + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::tavily-search", + pip_packages=[], + module="llama_stack.providers.inline.tool_runtime.tavily_search", + config_class="llama_stack.providers.inline.tool_runtime.tavily_search.config.TavilySearchToolConfig", + provider_data_validator="llama_stack.providers.inline.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index f5158b57c..c0690e4e3 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -77,6 +77,13 @@ def tool_runtime_memory() -> ProviderFixture: "api_key": os.environ["BRAVE_SEARCH_API_KEY"], }, ), + Provider( + provider_id="tavily-search", + provider_type="inline::tavily-search", + config={ + "api_key": os.environ["TAVILY_SEARCH_API_KEY"], + }, + ), ], ) @@ -146,13 +153,41 @@ async def agents_stack(request, inference_model, safety_shield): ToolDef( name="brave_search", description="brave_search", - parameters=[], + parameters=[ + ToolParameter( + name="query", + description="query", + parameter_type="string", + required=True, + ), + ], metadata={}, ), ], ), provider_id="brave-search", ), + ToolGroupInput( + tool_group_id="tavily_search_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="tavily_search", + description="tavily_search", + parameters=[ + ToolParameter( + name="query", + description="query", + parameter_type="string", + required=True, + ), + ], + metadata={}, + ), + ], + ), + provider_id="tavily-search", + ), ToolGroupInput( tool_group_id="memory_group", tool_group=UserDefinedToolGroupDef( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 78ca2341f..cd4f75418 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -149,7 +149,7 @@ async def create_agent_turn_with_search_tool( tool_execution = tool_execution_events[0].event.payload.step_details assert isinstance(tool_execution, ToolExecutionStep) assert len(tool_execution.tool_calls) > 0 - assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search + assert tool_execution.tool_calls[0].tool_name == tool_name assert len(tool_execution.tool_responses) > 0 check_turn_complete_event(turn_response, session_id, search_query_messages) @@ -302,6 +302,20 @@ class TestAgents: "brave_search", ) + @pytest.mark.asyncio + async def test_create_agent_turn_with_tavily_search( + self, agents_stack, search_query_messages, common_params + ): + if "TAVILY_SEARCH_API_KEY" not in os.environ: + pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") + + await create_agent_turn_with_search_tool( + agents_stack, + search_query_messages, + common_params, + "tavily_search", + ) + def check_event_types(turn_response): event_types = [chunk.event.payload.event_type for chunk in turn_response]