diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index c6b59e948..ce053fd66 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -26,11 +26,11 @@ class ToolParameter(BaseModel): @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool.value] = ResourceType.tool.value - name: str tool_group: str description: str parameters: List[ToolParameter] provider_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) @@ -55,12 +55,14 @@ class MCPToolGroup(BaseModel): """ type: Literal["model_context_protocol"] = "model_context_protocol" + name: str endpoint: URL @json_schema_type class UserDefinedToolGroup(BaseModel): type: Literal["user_defined"] = "user_defined" + name: str tools: List[ToolDef] @@ -87,7 +89,6 @@ class Tools(Protocol): @webmethod(route="/toolgroups/register", method="POST") async def register_tool_group( self, - name: str, tool_group: ToolGroup, provider_id: Optional[str] = None, ) -> None: @@ -115,6 +116,9 @@ class Tools(Protocol): class ToolRuntime(Protocol): tool_store: ToolStore + @webmethod(route="/tool-runtime/discover", method="POST") + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ... + @webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( self, tool_id: str, args: Dict[str, Any] diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 24fe89669..9c9cfec6f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -393,3 +393,8 @@ class ToolRuntimeRouter(ToolRuntime): tool_id=tool_id, args=args, ) + + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + return await self.routing_table.get_provider_impl( + tool_group.name + ).discover_tools(tool_group) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index cc458c32a..690a4e9b7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable): await add_objects(scoring_functions, pid, ScoringFn) elif api == Api.eval: p.eval_task_store = self + elif api == Api.tool_runtime: + p.tool_store = self async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable): return ("Scoring", "scoring_function") elif isinstance(self, EvalTasksRoutingTable): return ("Eval", "eval_task") + elif isinstance(self, ToolsRoutingTable): + return ("Tools", "tool") else: raise ValueError("Unknown routing table type") @@ -476,34 +480,40 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools): async def register_tool_group( self, - name: str, tool_group: ToolGroup, provider_id: Optional[str] = None, ) -> None: tools = [] if isinstance(tool_group, MCPToolGroup): - # TODO: first needs to be resolved to corresponding tools available in the MCP server - raise NotImplementedError("MCP tool provider not implemented yet") + # TODO: Actually find the right MCP provider + if provider_id is None: + raise ValueError("MCP provider_id not specified") + tools = await self.impls_by_provider_id[provider_id].discover_tools( + tool_group + ) + for tool in tools: + tool.provider_id = provider_id elif isinstance(tool_group, UserDefinedToolGroup): for tool in tool_group.tools: tools.append( Tool( identifier=tool.name, - tool_group=name, + tool_group=tool_group.name, name=tool.name, description=tool.description, parameters=tool.parameters, provider_id=provider_id, tool_prompt_format=tool.tool_prompt_format, provider_resource_id=tool.name, + metadata=tool.metadata, ) ) else: raise ValueError(f"Unknown tool group: {tool_group}") for tool in tools: - existing_tool = await self.get_tool(tool.name) + existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists if existing_tool: # Compare all fields except provider_id since that might be None in new obj diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py new file mode 100644 index 000000000..e9f0eeae8 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .brave_search import BraveSearchToolRuntimeImpl +from .config import BraveSearchToolConfig + + +class BraveSearchToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_provider_impl(config: BraveSearchToolConfig, _deps): + impl = BraveSearchToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py new file mode 100644 index 000000000..464963b40 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List + +import requests + +from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import BraveSearchToolConfig + + +class BraveSearchToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): + def __init__(self, config: BraveSearchToolConfig): + self.config = config + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + if tool.identifier != "brave_search": + raise ValueError(f"Tool identifier {tool.identifier} is not supported") + + async def unregister_tool(self, tool_id: str) -> None: + return + + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key + + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + raise NotImplementedError("Brave search tool group not supported") + + async def invoke_tool( + self, tool_id: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() + url = "https://api.search.brave.com/res/v1/web/search" + headers = { + "X-Subscription-Token": api_key, + "Accept-Encoding": "gzip", + "Accept": "application/json", + } + payload = {"q": args["query"]} + response = requests.get(url=url, params=payload, headers=headers) + response.raise_for_status() + results = self._clean_brave_response(response.json()) + content_items = "\n".join([str(result) for result in results]) + return ToolInvocationResult( + content=content_items, + ) + + def _clean_brave_response(self, search_response): + clean_response = [] + if "mixed" in search_response: + mixed_results = search_response["mixed"] + for m in mixed_results["main"][: self.config.max_results]: + r_type = m["type"] + results = search_response[r_type]["results"] + cleaned = self._clean_result_by_type(r_type, results, m.get("index")) + clean_response.append(cleaned) + + return clean_response + + def _clean_result_by_type(self, r_type, results, idx=None): + type_cleaners = { + "web": ( + ["type", "title", "url", "description", "date", "extra_snippets"], + lambda x: x[idx], + ), + "faq": (["type", "question", "answer", "title", "url"], lambda x: x), + "infobox": ( + ["type", "title", "url", "description", "long_desc"], + lambda x: x[idx], + ), + "videos": (["type", "url", "title", "description", "date"], lambda x: x), + "locations": ( + [ + "type", + "title", + "url", + "description", + "coordinates", + "postal_address", + "contact", + "rating", + "distance", + "zoom_level", + ], + lambda x: x, + ), + "news": (["type", "title", "url", "description"], lambda x: x), + } + + if r_type not in type_cleaners: + return "" + + selected_keys, result_selector = type_cleaners[r_type] + results = result_selector(results) + + if isinstance(results, list): + cleaned = [ + {k: v for k, v in item.items() if k in selected_keys} + for item in results + ] + else: + cleaned = {k: v for k, v in results.items() if k in selected_keys} + + return str(cleaned) diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/config.py b/llama_stack/providers/inline/tool_runtime/brave_search/config.py new file mode 100644 index 000000000..565d428f7 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/brave_search/config.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel, Field + + +class BraveSearchToolConfig(BaseModel): + api_key: Optional[str] = Field( + default=None, + description="The Brave Search API Key", + ) + max_results: int = Field( + default=3, + description="The maximum number of results to return", + ) diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py b/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py deleted file mode 100644 index f7d52c1f0..000000000 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import MetaReferenceToolRuntimeConfig -from .meta_reference import MetaReferenceToolRuntimeImpl - - -async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps): - impl = MetaReferenceToolRuntimeImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py deleted file mode 100644 index 087fd918d..000000000 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime -from llama_stack.providers.datatypes import ToolsProtocolPrivate - -from .config import MetaReferenceToolRuntimeConfig - - -class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): - def __init__(self, config: MetaReferenceToolRuntimeConfig): - self.config = config - - async def initialize(self): - pass - - async def register_tool(self, tool: Tool): - print(f"registering tool {tool.identifier}") - pass - - async def unregister_tool(self, tool_id: str) -> None: - pass - - async def invoke_tool( - self, tool_id: str, args: Dict[str, Any] - ) -> ToolInvocationResult: - pass diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index c0e7a3d1b..f3e6aead8 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,16 +6,32 @@ from typing import List -from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.distribution.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.tool_runtime, - provider_type="inline::meta-reference", + provider_type="inline::brave-search", pip_packages=[], - module="llama_stack.providers.inline.tool_runtime.meta_reference", - config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig", + module="llama_stack.providers.inline.tool_runtime.brave_search", + config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", + provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", + ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="model-context-protocol", + module="llama_stack.providers.remote.tool_runtime.model_context_protocol", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", + pip_packages=["mcp"], + ), ), ] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py new file mode 100644 index 000000000..3b05f5632 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import ModelContextProtocolConfig + +from .model_context_protocol import ModelContextProtocolToolRuntimeImpl + + +class ModelContextProtocolToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_adapter_impl(config: ModelContextProtocolConfig, _deps): + impl = ModelContextProtocolToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py similarity index 83% rename from llama_stack/providers/inline/tool_runtime/meta_reference/config.py rename to llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index 3f6146c51..ffe4c9887 100644 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/config.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -7,5 +7,5 @@ from pydantic import BaseModel -class MetaReferenceToolRuntimeConfig(BaseModel): +class ModelContextProtocolConfig(BaseModel): pass diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py new file mode 100644 index 000000000..0c6661731 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List +from urllib.parse import urlparse + +from llama_stack.apis.tools import ( + MCPToolGroup, + Tool, + ToolGroup, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) +from llama_stack.providers.datatypes import ToolsProtocolPrivate +from mcp import ClientSession +from mcp.client.sse import sse_client + +from .config import ModelContextProtocolConfig + + +class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: ModelContextProtocolConfig): + self.config = config + + async def initialize(self): + pass + + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + if not isinstance(tool_group, MCPToolGroup): + raise ValueError(f"Unsupported tool group type: {type(tool_group)}") + + tools = [] + async with sse_client(tool_group.endpoint.uri) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get( + "properties", {} + ).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), + ) + ) + tools.append( + Tool( + identifier=tool.name, + description=tool.description, + tool_group=tool_group.name, + parameters=parameters, + metadata={ + "endpoint": tool_group.endpoint.uri, + }, + provider_resource_id=tool.name, + ) + ) + return tools + + async def invoke_tool( + self, tool_id: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + tool = await self.tool_store.get_tool(tool_id) + if tool.metadata is None or tool.metadata.get("endpoint") is None: + raise ValueError(f"Tool {tool_id} does not have metadata") + endpoint = tool.metadata.get("endpoint") + if urlparse(endpoint).scheme not in ("http", "https"): + raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") + + async with sse_client(endpoint) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + result = await session.call_tool(tool.identifier, args) + + return ToolInvocationResult( + content="\n".join([result.model_dump_json() for result in result.content]), + error_code=1 if result.isError else 0, + )