diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 24845e101..0c2bb5863 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -53,6 +53,7 @@ class ToolDef(BaseModel): description: Optional[str] = None parameters: Optional[List[ToolParameter]] = None metadata: Optional[Dict[str, Any]] = None + built_in_type: Optional[BuiltinTool] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4ed932807..2f0288865 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -527,6 +527,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): provider_resource_id=tool_def.name, metadata=tool_def.metadata, tool_host=tool_host, + built_in_type=tool_def.built_in_type, ) ) for tool in tools: diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index fc568996d..3b3f18032 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -9,6 +9,8 @@ import logging import tempfile from typing import Any, Dict, List, Optional +from llama_models.llama3.api.datatypes import BuiltinTool + from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( Tool, @@ -56,6 +58,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): parameter_type="string", ), ], + built_in_type=BuiltinTool.code_interpreter, ) ] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index d6e892599..40299edad 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -42,6 +42,16 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", ), ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="bing-search", + module="llama_stack.providers.remote.tool_runtime.bing_search", + config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", + ), + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( @@ -52,6 +62,16 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", ), ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="wolfram-alpha", + module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", + config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", + ), + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py b/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py new file mode 100644 index 000000000..8481737b5 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .bing_search import BingSearchToolRuntimeImpl +from .config import BingSearchToolConfig + +__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"] +from pydantic import BaseModel + + +class BingSearchToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_adapter_impl(config: BingSearchToolConfig, _deps): + impl = BingSearchToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py new file mode 100644 index 000000000..b0c30b0a0 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any, Dict, List, Optional + +import requests +from llama_models.llama3.api.datatypes import BuiltinTool + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import BingSearchToolConfig + + +class BingSearchToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): + def __init__(self, config: BingSearchToolConfig): + self.config = config + self.url = "https://api.bing.microsoft.com/v7.0/search" + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + pass + + async def unregister_tool(self, tool_id: str) -> None: + return + + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key + + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass Bing Search API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="web_search", + description="Search the web using Bing Search API", + parameters=[ + ToolParameter( + name="query", + description="The query to search for", + parameter_type="string", + ) + ], + built_in_type=BuiltinTool.brave_search, + ) + ] + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() + headers = { + "Ocp-Apim-Subscription-Key": api_key, + } + params = { + "count": self.config.top_k, + "textDecorations": True, + "textFormat": "HTML", + "q": args["query"], + } + + response = requests.get( + url=self.url, + params=params, + headers=headers, + ) + response.raise_for_status() + + return ToolInvocationResult( + content=json.dumps(self._clean_response(response.json())) + ) + + def _clean_response(self, search_response): + clean_response = [] + query = search_response["queryContext"]["originalQuery"] + if "webPages" in search_response: + pages = search_response["webPages"]["value"] + for p in pages: + selected_keys = {"name", "url", "snippet"} + clean_response.append( + {k: v for k, v in p.items() if k in selected_keys} + ) + if "news" in search_response: + clean_news = [] + news = search_response["news"]["value"] + for n in news: + selected_keys = {"name", "url", "description"} + clean_news.append({k: v for k, v in n.items() if k in selected_keys}) + + clean_response.append(clean_news) + + return {"query": query, "top_k": clean_response} diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/config.py b/llama_stack/providers/remote/tool_runtime/bing_search/config.py new file mode 100644 index 000000000..67283d8d5 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/bing_search/config.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel + + +class BingSearchToolConfig(BaseModel): + """Configuration for Bing Search Tool Runtime""" + + api_key: Optional[str] = None + top_k: int = 3 diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 162e82d62..dab6ce439 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional import requests +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -62,6 +63,7 @@ class BraveSearchToolRuntimeImpl( parameter_type="string", ) ], + built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 6dc515be3..d22f188b3 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -8,6 +8,7 @@ import json from typing import Any, Dict, List, Optional import requests +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -63,6 +64,7 @@ class TavilySearchToolRuntimeImpl( parameter_type="string", ) ], + built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py new file mode 100644 index 000000000..aaa6e4e69 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import WolframAlphaToolConfig +from .wolfram_alpha import WolframAlphaToolRuntimeImpl + +__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"] + + +class WolframAlphaToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_adapter_impl(config: WolframAlphaToolConfig, _deps): + impl = WolframAlphaToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py new file mode 100644 index 000000000..13996b639 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel + + +class WolframAlphaToolConfig(BaseModel): + """Configuration for WolframAlpha Tool Runtime""" + + api_key: Optional[str] = None diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py new file mode 100644 index 000000000..0f3fdfb39 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any, Dict, List, Optional + +import requests +from llama_models.llama3.api.datatypes import BuiltinTool + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import WolframAlphaToolConfig + + +class WolframAlphaToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): + def __init__(self, config: WolframAlphaToolConfig): + self.config = config + self.url = "https://api.wolframalpha.com/v2/query" + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + pass + + async def unregister_tool(self, tool_id: str) -> None: + return + + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key + + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass WolframAlpha API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="wolfram_alpha", + description="Query WolframAlpha for computational knowledge", + parameters=[ + ToolParameter( + name="query", + description="The query to compute", + parameter_type="string", + ) + ], + built_in_type=BuiltinTool.wolfram_alpha, + ) + ] + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() + params = { + "input": args["query"], + "appid": api_key, + "format": "plaintext", + "output": "json", + } + response = requests.get( + self.url, + params=params, + ) + + return ToolInvocationResult( + content=json.dumps(self._clean_wolfram_alpha_response(response.json())) + ) + + def _clean_wolfram_alpha_response(self, wa_response): + remove = { + "queryresult": [ + "datatypes", + "error", + "timedout", + "timedoutpods", + "numpods", + "timing", + "parsetiming", + "parsetimedout", + "recalculate", + "id", + "host", + "server", + "related", + "version", + { + "pods": [ + "scanner", + "id", + "error", + "expressiontypes", + "states", + "infos", + "position", + "numsubpods", + ] + }, + "assumptions", + ], + } + for main_key in remove: + for key_to_remove in remove[main_key]: + try: + if key_to_remove == "assumptions": + if "assumptions" in wa_response[main_key]: + del wa_response[main_key][key_to_remove] + if isinstance(key_to_remove, dict): + for sub_key in key_to_remove: + if sub_key == "pods": + for i in range(len(wa_response[main_key][sub_key])): + if ( + wa_response[main_key][sub_key][i]["title"] + == "Result" + ): + del wa_response[main_key][sub_key][i + 1 :] + break + sub_items = wa_response[main_key][sub_key] + for i in range(len(sub_items)): + for sub_key_to_remove in key_to_remove[sub_key]: + if sub_key_to_remove in sub_items[i]: + del sub_items[i][sub_key_to_remove] + elif key_to_remove in wa_response[main_key]: + del wa_response[main_key][key_to_remove] + except KeyError: + pass + return wa_response diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py index a9f923c87..a559dbf8c 100644 --- a/llama_stack/providers/tests/tools/fixtures.py +++ b/llama_stack/providers/tests/tools/fixtures.py @@ -33,6 +33,13 @@ def tool_runtime_memory_and_search() -> ProviderFixture: "api_key": os.environ["TAVILY_SEARCH_API_KEY"], }, ), + Provider( + provider_id="wolfram-alpha", + provider_type="remote::wolfram-alpha", + config={ + "api_key": os.environ["WOLFRAM_ALPHA_API_KEY"], + }, + ), ], ) @@ -53,12 +60,24 @@ def tool_group_input_tavily_search() -> ToolGroupInput: ) +@pytest.fixture(scope="session") +def tool_group_input_wolfram_alpha() -> ToolGroupInput: + return ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ) + + TOOL_RUNTIME_FIXTURES = ["memory_and_search"] @pytest_asyncio.fixture(scope="session") async def tools_stack( - request, inference_model, tool_group_input_memory, tool_group_input_tavily_search + request, + inference_model, + tool_group_input_memory, + tool_group_input_tavily_search, + tool_group_input_wolfram_alpha, ): fixture_dict = request.param @@ -104,6 +123,7 @@ async def tools_stack( models=models, tool_groups=[ tool_group_input_tavily_search, + tool_group_input_wolfram_alpha, tool_group_input_memory, ], ) diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py index 917db55e1..16081b939 100644 --- a/llama_stack/providers/tests/tools/test_tools.py +++ b/llama_stack/providers/tests/tools/test_tools.py @@ -20,6 +20,11 @@ def sample_search_query(): return "What are the latest developments in quantum computing?" +@pytest.fixture +def sample_wolfram_alpha_query(): + return "What is the square root of 16?" + + @pytest.fixture def sample_documents(): urls = [ @@ -61,6 +66,24 @@ class TestTools: assert len(response.content) > 0 assert isinstance(response.content, str) + @pytest.mark.asyncio + async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query): + """Test the wolfram alpha tool functionality.""" + if "WOLFRAM_ALPHA_API_KEY" not in os.environ: + pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test") + + tools_impl = tools_stack.impls[Api.tool_runtime] + + response = await tools_impl.invoke_tool( + tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query} + ) + + # Verify the response + assert isinstance(response, ToolInvocationResult) + assert response.content is not None + assert len(response.content) > 0 + assert isinstance(response.content, str) + @pytest.mark.asyncio async def test_memory_tool(self, tools_stack, sample_documents): """Test the memory tool functionality."""