From a297d27d48b099db94dfa36dbe0a4c79131b768b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 19 Dec 2024 16:15:08 -0800 Subject: [PATCH] add brave tool provider (#653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Adds a new brave tool provider ## Test Plan ``` curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \ -H 'Content-Type: application/json' \ -d '{ "name": "search", "tool_group": { "type": "user_defined", "tools": [ { "name": "brave_search", "description": "A web search tool", "parameters": [ { "name": "query", "parameter_type": "string", "description": "The query to search" } ], "metadata": {}, "tool_prompt_format": "json" } ] } }' curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \ -H 'Content-Type: application/json' \ -d '{ "tool_id": "brave_search", "args": { "query": "who is meta ceo" } }' | jq .content % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 1973 100 1884 100 89 11288 533 --:--:-- --:--:-- --:--:-- 11885 "{'title': 'Mark Zuckerberg, Founder, Chairman and Chief Executive ...', 'url': 'https://about.meta.com/media-gallery/executives/mark-zuckerberg/', 'description': 'Not Logged In · Please log in to see this page', 'type': 'search_result'}\n{'title': 'Meta - Leadership & Governance', 'url': 'https://investor.fb.com/leadership-and-governance/', 'description': 'Mark Zuckerberg is the founder, chairman and CEO of Meta, which he originally founded as Facebook in 2004. Mark is responsible for setting the overall direction and product strategy for the company. 
He leads the design of Meta's services and development of its core technology and infrastructure.', 'type': 'search_result'}\n[{'type': 'video_result', 'url': 'https://www.facebook.com/etribune/videos/mark-zuckerberg-the-ceo-of-meta-has-officially-joined-the-exclusive-200-billion-/2372542949752515/', 'title': 'Mark Zuckerberg, the CEO of Meta, has officially joined the ...', 'description': \"Express Tribune, Karachi, Pakistan. 2,334,400 likes · 36,360 talking about this · 205 were here. The Express Tribune is Pakistan's #1 brand for breaking news in politics, sports, business, lifestyle\"}, {'type': 'video_result', 'url': 'https://www.youtube.com/watch?v=Y3oeQqtRvqk', 'title': \"Meta CEO: Mark Zuckerberg becomes World's Second Richest Person!\", 'description': 'Try VectorVest Risk-Free ➥➥➥ https://www.vectorvest.com/YTUse this link for a FREE Stock Analysis Report ➥➥➥ vectorvest.com/YTFSAVectorVest Merch Store ➥➥➥'}, {'type': 'video_result', 'url': 'https://m.facebook.com/abplive/videos/watch-meta-founder-and-ceo-mark-zuckerberg-recently-posted-a-heartwarming-video-/534841222497600/', 'title': '#WATCH | Meta founder and CEO Mark Zuckerberg recently ...', 'description': 'See posts, photos and more on Facebook'}]" curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \ -H 'Content-Type: application/json' -H 'X-LlamaStack-ProviderData: {"api_key": ""}' \ -d '{ "tool_id": "brave_search", "args": { "query": "who is meta ceo" } }' ``` --- llama_stack/apis/tools/tools.py | 8 +- llama_stack/distribution/routers/routers.py | 5 + .../distribution/routers/routing_tables.py | 20 ++- .../tool_runtime/brave_search/__init__.py | 20 +++ .../tool_runtime/brave_search/brave_search.py | 123 ++++++++++++++++++ .../tool_runtime/brave_search/config.py | 20 +++ .../tool_runtime/meta_reference/__init__.py | 14 -- .../meta_reference/meta_reference.py | 32 ----- .../providers/registry/tool_runtime.py | 24 +++- .../model_context_protocol/__init__.py | 21 +++ 
.../model_context_protocol}/config.py | 2 +- .../model_context_protocol.py | 85 ++++++++++++ 12 files changed, 316 insertions(+), 58 deletions(-) create mode 100644 llama_stack/providers/inline/tool_runtime/brave_search/__init__.py create mode 100644 llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py create mode 100644 llama_stack/providers/inline/tool_runtime/brave_search/config.py delete mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py delete mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py create mode 100644 llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py rename llama_stack/providers/{inline/tool_runtime/meta_reference => remote/tool_runtime/model_context_protocol}/config.py (83%) create mode 100644 llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index c6b59e948..ce053fd66 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -26,11 +26,11 @@ class ToolParameter(BaseModel): @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool.value] = ResourceType.tool.value - name: str tool_group: str description: str parameters: List[ToolParameter] provider_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) @@ -55,12 +55,14 @@ class MCPToolGroup(BaseModel): """ type: Literal["model_context_protocol"] = "model_context_protocol" + name: str endpoint: URL @json_schema_type class UserDefinedToolGroup(BaseModel): type: Literal["user_defined"] = "user_defined" + name: str tools: List[ToolDef] @@ -87,7 +89,6 @@ class Tools(Protocol): @webmethod(route="/toolgroups/register", method="POST") async def register_tool_group( self, - name: str, tool_group: ToolGroup, provider_id: Optional[str] = 
None, ) -> None: @@ -115,6 +116,9 @@ class Tools(Protocol): class ToolRuntime(Protocol): tool_store: ToolStore + @webmethod(route="/tool-runtime/discover", method="POST") + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ... + @webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( self, tool_id: str, args: Dict[str, Any] diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 24fe89669..9c9cfec6f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -393,3 +393,8 @@ class ToolRuntimeRouter(ToolRuntime): tool_id=tool_id, args=args, ) + + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + return await self.routing_table.get_provider_impl( + tool_group.name + ).discover_tools(tool_group) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index cc458c32a..690a4e9b7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable): await add_objects(scoring_functions, pid, ScoringFn) elif api == Api.eval: p.eval_task_store = self + elif api == Api.tool_runtime: + p.tool_store = self async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable): return ("Scoring", "scoring_function") elif isinstance(self, EvalTasksRoutingTable): return ("Eval", "eval_task") + elif isinstance(self, ToolsRoutingTable): + return ("Tools", "tool") else: raise ValueError("Unknown routing table type") @@ -476,34 +480,40 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools): async def register_tool_group( self, - name: str, tool_group: ToolGroup, provider_id: Optional[str] = None, ) -> None: tools = [] if isinstance(tool_group, MCPToolGroup): - # TODO: first 
needs to be resolved to corresponding tools available in the MCP server - raise NotImplementedError("MCP tool provider not implemented yet") + # TODO: Actually find the right MCP provider + if provider_id is None: + raise ValueError("MCP provider_id not specified") + tools = await self.impls_by_provider_id[provider_id].discover_tools( + tool_group + ) + for tool in tools: + tool.provider_id = provider_id elif isinstance(tool_group, UserDefinedToolGroup): for tool in tool_group.tools: tools.append( Tool( identifier=tool.name, - tool_group=name, + tool_group=tool_group.name, name=tool.name, description=tool.description, parameters=tool.parameters, provider_id=provider_id, tool_prompt_format=tool.tool_prompt_format, provider_resource_id=tool.name, + metadata=tool.metadata, ) ) else: raise ValueError(f"Unknown tool group: {tool_group}") for tool in tools: - existing_tool = await self.get_tool(tool.name) + existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists if existing_tool: # Compare all fields except provider_id since that might be None in new obj diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py new file mode 100644 index 000000000..e9f0eeae8 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
+from pydantic import BaseModel
+
+from .brave_search import BraveSearchToolRuntimeImpl
+from .config import BraveSearchToolConfig
+
+
+class BraveSearchToolProviderDataValidator(BaseModel):
+    api_key: str
+
+
+async def get_provider_impl(config: BraveSearchToolConfig, _deps):
+    impl = BraveSearchToolRuntimeImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py
new file mode 100644
index 000000000..464963b40
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from functools import partial
+from typing import Any, Dict, List
+
+import requests
+
+from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.datatypes import ToolsProtocolPrivate
+
+from .config import BraveSearchToolConfig
+
+BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
+# Upper bound (seconds) for one Brave API call; requests has no default timeout.
+REQUEST_TIMEOUT_SECONDS = 30
+
+
+class BraveSearchToolRuntimeImpl(
+    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
+):
+    """Tool runtime that serves the ``brave_search`` tool from the Brave Search API."""
+
+    def __init__(self, config: BraveSearchToolConfig):
+        self.config = config
+
+    async def initialize(self):
+        pass
+
+    async def register_tool(self, tool: Tool):
+        """Accept only the ``brave_search`` tool; raise ``ValueError`` otherwise."""
+        if tool.identifier != "brave_search":
+            raise ValueError(f"Tool identifier {tool.identifier} is not supported")
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        return
+
+    def _get_api_key(self) -> str:
+        """Return the configured API key, else the key from the request provider data."""
+        if self.config.api_key:
+            return self.config.api_key
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None or not provider_data.api_key:
+            raise ValueError(
+                'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
+            )
+        return provider_data.api_key
+
+    async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
+        raise NotImplementedError("Brave search tool group not supported")
+
+    async def invoke_tool(
+        self, tool_id: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
+        """Search Brave for ``args["query"]`` and return the cleaned results as text.
+
+        Raises ``ValueError`` when the query or the API key is missing, and
+        ``requests.HTTPError`` when Brave answers with an error status.
+        """
+        query = args.get("query")
+        if not query:
+            raise ValueError("brave_search requires a non-empty 'query' argument")
+
+        # Resolve the key in the request's own context, before the blocking HTTP
+        # call is handed to a worker thread.
+        headers = {
+            "X-Subscription-Token": self._get_api_key(),
+            "Accept-Encoding": "gzip",
+            "Accept": "application/json",
+        }
+        fetch = partial(
+            requests.get,
+            url=BRAVE_SEARCH_URL,
+            params={"q": query},
+            headers=headers,
+            timeout=REQUEST_TIMEOUT_SECONDS,
+        )
+        # requests is synchronous; run it on the default executor so the event
+        # loop is not blocked while waiting on the network.
+        response = await asyncio.get_running_loop().run_in_executor(None, fetch)
+        response.raise_for_status()
+        results = self._clean_brave_response(response.json())
+        return ToolInvocationResult(content="\n".join(results))
+
+    def _clean_brave_response(self, search_response):
+        """Return one cleaned string per usable entry of the ``mixed.main`` ranking.
+
+        Only the first ``config.max_results`` ranking entries are considered;
+        entries whose type is unsupported or missing from the response are skipped.
+        """
+        ranking = (search_response.get("mixed") or {}).get("main") or []
+        clean_response = []
+        for entry in ranking[: self.config.max_results]:
+            r_type = entry.get("type")
+            results = (search_response.get(r_type) or {}).get("results")
+            if results is None:
+                continue
+            cleaned = self._clean_result_by_type(r_type, results, entry.get("index"))
+            if cleaned:
+                clean_response.append(cleaned)
+
+        return clean_response
+
+    def _clean_result_by_type(self, r_type, results, idx=None):
+        """Keep only whitelisted keys of ``results`` and return them as a string.
+
+        For ``web`` and ``infobox`` only the item at ``idx`` is kept; the other
+        types keep every item. Returns ``""`` for an unsupported type or an
+        out-of-range ``idx``.
+        """
+        # r_type -> (keys to keep, whether the ranking entry selects one item by index)
+        type_cleaners = {
+            "web": (
+                ["type", "title", "url", "description", "date", "extra_snippets"],
+                True,
+            ),
+            "faq": (["type", "question", "answer", "title", "url"], False),
+            "infobox": (
+                ["type", "title", "url", "description", "long_desc"],
+                True,
+            ),
+            "videos": (["type", "url", "title", "description", "date"], False),
+            "locations": (
+                [
+                    "type",
+                    "title",
+                    "url",
+                    "description",
+                    "coordinates",
+                    "postal_address",
+                    "contact",
+                    "rating",
+                    "distance",
+                    "zoom_level",
+                ],
+                False,
+            ),
+            "news": (["type", "title", "url", "description"], False),
+        }
+
+        if r_type not in type_cleaners:
+            return ""
+
+        selected_keys, pick_by_index = type_cleaners[r_type]
+        # NOTE(review): an entry without an index is presumed to mean "every result
+        # of this type"; confirm against the Brave API reference.
+        if pick_by_index and idx is not None and isinstance(results, list):
+            if not 0 <= idx < len(results):
+                return ""
+            results = results[idx]
+
+        if isinstance(results, list):
+            cleaned = [
+                {k: v for k, v in item.items() if k in selected_keys}
+                for item in results
+            ]
+        else:
+            cleaned = {k: v for k, v in results.items() if k in selected_keys}
+
+        return str(cleaned)
diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/config.py b/llama_stack/providers/inline/tool_runtime/brave_search/config.py
new file mode 100644
index 000000000..565d428f7
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/brave_search/config.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class BraveSearchToolConfig(BaseModel):
+    api_key: Optional[str] = Field(
+        default=None,
+        description="The Brave Search API Key",
+    )
+    max_results: int = Field(
+        default=3,
+        description="The maximum number of results to return",
+    )
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py b/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py
deleted file mode 100644
index f7d52c1f0..000000000
--- a/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
- -from .config import MetaReferenceToolRuntimeConfig -from .meta_reference import MetaReferenceToolRuntimeImpl - - -async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps): - impl = MetaReferenceToolRuntimeImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py deleted file mode 100644 index 087fd918d..000000000 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime -from llama_stack.providers.datatypes import ToolsProtocolPrivate - -from .config import MetaReferenceToolRuntimeConfig - - -class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): - def __init__(self, config: MetaReferenceToolRuntimeConfig): - self.config = config - - async def initialize(self): - pass - - async def register_tool(self, tool: Tool): - print(f"registering tool {tool.identifier}") - pass - - async def unregister_tool(self, tool_id: str) -> None: - pass - - async def invoke_tool( - self, tool_id: str, args: Dict[str, Any] - ) -> ToolInvocationResult: - pass diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index c0e7a3d1b..f3e6aead8 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,16 +6,32 @@ from typing import List -from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.distribution.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + 
ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.tool_runtime, - provider_type="inline::meta-reference", + provider_type="inline::brave-search", pip_packages=[], - module="llama_stack.providers.inline.tool_runtime.meta_reference", - config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig", + module="llama_stack.providers.inline.tool_runtime.brave_search", + config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", + provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", + ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="model-context-protocol", + module="llama_stack.providers.remote.tool_runtime.model_context_protocol", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", + pip_packages=["mcp"], + ), ), ] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py new file mode 100644 index 000000000..3b05f5632 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
+from pydantic import BaseModel
+
+from .config import ModelContextProtocolConfig
+
+from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
+
+
+class ModelContextProtocolToolProviderDataValidator(BaseModel):
+    api_key: str
+
+
+async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
+    impl = ModelContextProtocolToolRuntimeImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
similarity index 83%
rename from llama_stack/providers/inline/tool_runtime/meta_reference/config.py
rename to llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
index 3f6146c51..ffe4c9887 100644
--- a/llama_stack/providers/inline/tool_runtime/meta_reference/config.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
@@ -7,5 +7,5 @@
 from pydantic import BaseModel
 
 
-class MetaReferenceToolRuntimeConfig(BaseModel):
+class ModelContextProtocolConfig(BaseModel):
     pass
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
new file mode 100644
index 000000000..0c6661731
--- /dev/null
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+from urllib.parse import urlparse
+
+from llama_stack.apis.tools import (
+    MCPToolGroup,
+    Tool,
+    ToolGroup,
+    ToolInvocationResult,
+    ToolParameter,
+    ToolRuntime,
+)
+from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+from .config import ModelContextProtocolConfig
+
+
+class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+    """Tool runtime that discovers and invokes tools on MCP servers over SSE."""
+
+    def __init__(self, config: ModelContextProtocolConfig):
+        self.config = config
+
+    async def initialize(self):
+        pass
+
+    @staticmethod
+    def _validate_endpoint(endpoint: str) -> str:
+        """Return ``endpoint`` if it is an HTTP(S) URL, otherwise raise ``ValueError``."""
+        if urlparse(endpoint).scheme not in ("http", "https"):
+            raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
+        return endpoint
+
+    async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
+        """List the tools served by the MCP server at ``tool_group.endpoint``.
+
+        Each returned ``Tool`` stores the endpoint in its metadata, which is where
+        ``invoke_tool`` looks it up.
+        """
+        if not isinstance(tool_group, MCPToolGroup):
+            raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
+
+        # Reject up front what invoke_tool would reject later, so tools that can
+        # never be invoked are not returned for registration.
+        endpoint = self._validate_endpoint(tool_group.endpoint.uri)
+
+        tools = []
+        async with sse_client(endpoint) as streams:
+            async with ClientSession(*streams) as session:
+                await session.initialize()
+                tools_result = await session.list_tools()
+                for tool in tools_result.tools:
+                    properties = (tool.inputSchema or {}).get("properties") or {}
+                    parameters = [
+                        ToolParameter(
+                            name=param_name,
+                            parameter_type=param_schema.get("type", "string"),
+                            description=param_schema.get("description", ""),
+                        )
+                        for param_name, param_schema in properties.items()
+                    ]
+                    tools.append(
+                        Tool(
+                            identifier=tool.name,
+                            # The MCP description may be absent; Tool requires a str.
+                            description=tool.description or "",
+                            tool_group=tool_group.name,
+                            parameters=parameters,
+                            metadata={
+                                "endpoint": endpoint,
+                            },
+                            provider_resource_id=tool.name,
+                        )
+                    )
+        return tools
+
+    async def invoke_tool(
+        self, tool_id: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
+        """Call ``tool_id`` on the MCP server recorded in the tool's metadata.
+
+        Raises ``ValueError`` when the tool is unknown, has no endpoint recorded,
+        or the endpoint is not an HTTP(S) URL.
+        """
+        tool = await self.tool_store.get_tool(tool_id)
+        if tool is None:
+            raise ValueError(f"Tool {tool_id} not found")
+        if tool.metadata is None or tool.metadata.get("endpoint") is None:
+            raise ValueError(f"Tool {tool_id} does not have metadata")
+        endpoint = self._validate_endpoint(tool.metadata.get("endpoint"))
+
+        async with sse_client(endpoint) as streams:
+            async with ClientSession(*streams) as session:
+                await session.initialize()
+                result = await session.call_tool(tool.identifier, args)
+
+        return ToolInvocationResult(
+            content="\n".join(item.model_dump_json() for item in result.content),
+            error_code=1 if result.isError else 0,
+        )