From e95c168bc09047f71d55a3a79c5e8c427c4a753e Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 19 Dec 2024 16:14:23 -0800 Subject: [PATCH] add model context protocol provider (#665) # What does this PR do? Changes: * Adds a new API to discover tools available on a runtime * Adds a new model context protocol provider ## Test Plan ``` # clone python sdk for mcp and start the simple-tool server uv run mcp-simple-tool --transport sse --port 56000 curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \ -H 'Content-Type: application/json' \ -d '{ "tool_group": { "name": "simple_mcp_group", "type": "model_context_protocol", "endpoint": {"uri": "http://localhost:56000/sse"} }, "provider_id": "model-context-protocol" }' curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \ -H 'Content-Type: application/json' \ -d '{ "tool_id": "fetch", "args": { "url": "http://google.com/" } }' ``` --- llama_stack/apis/tools/tools.py | 8 +- llama_stack/distribution/routers/routers.py | 5 ++ .../distribution/routers/routing_tables.py | 16 ++-- .../tool_runtime/brave_search/brave_search.py | 7 +- .../providers/registry/tool_runtime.py | 17 +++- .../model_context_protocol/__init__.py | 21 +++++ .../model_context_protocol/config.py | 11 +++ .../model_context_protocol.py | 85 +++++++++++++++++++ 8 files changed, 160 insertions(+), 10 deletions(-) create mode 100644 llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py create mode 100644 llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py create mode 100644 llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index c6b59e948..ce053fd66 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -26,11 +26,11 @@ class ToolParameter(BaseModel): @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool.value] = ResourceType.tool.value - name: str tool_group: str description: str parameters: List[ToolParameter] provider_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) @@ -55,12 +55,14 @@ class MCPToolGroup(BaseModel): """ type: Literal["model_context_protocol"] = "model_context_protocol" + name: str endpoint: URL @json_schema_type class UserDefinedToolGroup(BaseModel): type: Literal["user_defined"] = "user_defined" + name: str tools: List[ToolDef] @@ -87,7 +89,6 @@ class Tools(Protocol): @webmethod(route="/toolgroups/register", method="POST") async def register_tool_group( self, - name: str, tool_group: ToolGroup, provider_id: Optional[str] = None, ) -> None: @@ -115,6 +116,9 @@ class Tools(Protocol): class ToolRuntime(Protocol): tool_store: ToolStore + @webmethod(route="/tool-runtime/discover", method="POST") + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ... + @webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( self, tool_id: str, args: Dict[str, Any] diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 24fe89669..9c9cfec6f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -393,3 +393,8 @@ class ToolRuntimeRouter(ToolRuntime): tool_id=tool_id, args=args, ) + + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + return await self.routing_table.get_provider_impl( + tool_group.name + ).discover_tools(tool_group) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 556edc434..690a4e9b7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -480,34 +480,40 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools): async def register_tool_group( self, - name: str, tool_group: ToolGroup, provider_id: Optional[str] = None, ) -> None: tools = [] if isinstance(tool_group, MCPToolGroup): - # TODO: first needs to be resolved to corresponding tools available in the MCP server - raise NotImplementedError("MCP tool provider not implemented yet") + # TODO: Actually find the right MCP provider + if provider_id is None: + raise ValueError("MCP provider_id not specified") + tools = await self.impls_by_provider_id[provider_id].discover_tools( + tool_group + ) + for tool in tools: + tool.provider_id = provider_id elif isinstance(tool_group, UserDefinedToolGroup): for tool in tool_group.tools: tools.append( Tool( identifier=tool.name, - tool_group=name, + tool_group=tool_group.name, name=tool.name, description=tool.description, parameters=tool.parameters, provider_id=provider_id, tool_prompt_format=tool.tool_prompt_format, provider_resource_id=tool.name, + metadata=tool.metadata, ) ) else: raise ValueError(f"Unknown tool group: {tool_group}") for tool in tools: - existing_tool = await self.get_tool(tool.name) + existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists if existing_tool: # Compare all fields except provider_id since that might be None in new obj diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py index cb673d88f..464963b40 100644 --- a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py @@ -4,11 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any, Dict, List import requests -from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime +from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -42,6 +42,9 @@ class BraveSearchToolRuntimeImpl( ) return provider_data.api_key + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + raise NotImplementedError("Brave search tool group not supported") + async def invoke_tool( self, tool_id: str, args: Dict[str, Any] ) -> ToolInvocationResult: diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index a732845be..f3e6aead8 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.distribution.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: @@ -19,4 +25,13 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="model-context-protocol", + module="llama_stack.providers.remote.tool_runtime.model_context_protocol", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", + pip_packages=["mcp"], + ), + ), ] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py new file mode 100644 index 000000000..3b05f5632 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import ModelContextProtocolConfig + +from .model_context_protocol import ModelContextProtocolToolRuntimeImpl + + +class ModelContextProtocolToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_adapter_impl(config: ModelContextProtocolConfig, _deps): + impl = ModelContextProtocolToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py new file mode 100644 index 000000000..ffe4c9887 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class ModelContextProtocolConfig(BaseModel): + pass diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py new file mode 100644 index 000000000..0c6661731 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List +from urllib.parse import urlparse + +from llama_stack.apis.tools import ( + MCPToolGroup, + Tool, + ToolGroup, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) +from llama_stack.providers.datatypes import ToolsProtocolPrivate +from mcp import ClientSession +from mcp.client.sse import sse_client + +from .config import ModelContextProtocolConfig + + +class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: ModelContextProtocolConfig): + self.config = config + + async def initialize(self): + pass + + async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + if not isinstance(tool_group, MCPToolGroup): + raise ValueError(f"Unsupported tool group type: {type(tool_group)}") + + tools = [] + async with sse_client(tool_group.endpoint.uri) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get( + "properties", {} + ).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), + ) + ) + tools.append( + Tool( + identifier=tool.name, + description=tool.description, + tool_group=tool_group.name, + parameters=parameters, + metadata={ + "endpoint": tool_group.endpoint.uri, + }, + provider_resource_id=tool.name, + ) + ) + return tools + + async def invoke_tool( + self, tool_id: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + tool = await self.tool_store.get_tool(tool_id) + if tool.metadata is None or tool.metadata.get("endpoint") is None: + raise ValueError(f"Tool {tool_id} does not have metadata") + endpoint = tool.metadata.get("endpoint") + if urlparse(endpoint).scheme not in ("http", "https"): + raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") + + async with sse_client(endpoint) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + result = await session.call_tool(tool.identifier, args) + + return ToolInvocationResult( + content="\n".join([result.model_dump_json() for result in result.content]), + error_code=1 if result.isError else 0, + )