add model context protocol provider (#665)

# What does this PR do?
Changes:
* Adds a new API to discover tools available on a runtime
* Adds a new model context protocol provider


## Test Plan

```
# clone python sdk for mcp and start the simple-tool server
uv run mcp-simple-tool --transport sse --port 56000

curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \
-H 'Content-Type: application/json' \
-d '{
  "tool_group": { "name": "simple_mcp_group",
    "type": "model_context_protocol",
    "endpoint": {"uri": "http://localhost:56000/sse"}
  },
  "provider_id": "model-context-protocol"
}'

curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \
-H 'Content-Type: application/json' \
-d '{
    "tool_id": "fetch",
    "args": {
        "url": "http://google.com/"
    }
}'

```
This commit is contained in:
Dinesh Yeduguru 2024-12-19 16:14:23 -08:00 committed by GitHub
parent dc21e14f64
commit e95c168bc0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 160 additions and 10 deletions

View file

@ -26,11 +26,11 @@ class ToolParameter(BaseModel):
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
name: str
tool_group: str
description: str
parameters: List[ToolParameter]
provider_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@ -55,12 +55,14 @@ class MCPToolGroup(BaseModel):
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
name: str
endpoint: URL
@json_schema_type
class UserDefinedToolGroup(BaseModel):
type: Literal["user_defined"] = "user_defined"
name: str
tools: List[ToolDef]
@ -87,7 +89,6 @@ class Tools(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
name: str,
tool_group: ToolGroup,
provider_id: Optional[str] = None,
) -> None:
@ -115,6 +116,9 @@ class Tools(Protocol):
class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]

View file

@ -393,3 +393,8 @@ class ToolRuntimeRouter(ToolRuntime):
tool_id=tool_id,
args=args,
)
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
return await self.routing_table.get_provider_impl(
tool_group.name
).discover_tools(tool_group)

View file

@ -480,34 +480,40 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools):
async def register_tool_group(
self,
name: str,
tool_group: ToolGroup,
provider_id: Optional[str] = None,
) -> None:
tools = []
if isinstance(tool_group, MCPToolGroup):
# TODO: first needs to be resolved to corresponding tools available in the MCP server
raise NotImplementedError("MCP tool provider not implemented yet")
# TODO: Actually find the right MCP provider
if provider_id is None:
raise ValueError("MCP provider_id not specified")
tools = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
for tool in tools:
tool.provider_id = provider_id
elif isinstance(tool_group, UserDefinedToolGroup):
for tool in tool_group.tools:
tools.append(
Tool(
identifier=tool.name,
tool_group=name,
tool_group=tool_group.name,
name=tool.name,
description=tool.description,
parameters=tool.parameters,
provider_id=provider_id,
tool_prompt_format=tool.tool_prompt_format,
provider_resource_id=tool.name,
metadata=tool.metadata,
)
)
else:
raise ValueError(f"Unknown tool group: {tool_group}")
for tool in tools:
existing_tool = await self.get_tool(tool.name)
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
# Compare all fields except provider_id since that might be None in new obj

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any, Dict, List
import requests
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -42,6 +42,9 @@ class BraveSearchToolRuntimeImpl(
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:

View file

@ -6,7 +6,13 @@
from typing import List
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.distribution.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
def available_providers() -> List[ProviderSpec]:
@ -19,4 +25,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
pip_packages=["mcp"],
),
),
]

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
impl = ModelContextProtocolToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
pass

View file

@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from urllib.parse import urlparse
from llama_stack.apis.tools import (
MCPToolGroup,
Tool,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from mcp import ClientSession
from mcp.client.sse import sse_client
from .config import ModelContextProtocolConfig
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig):
self.config = config
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
if not isinstance(tool_group, MCPToolGroup):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
tools = []
async with sse_client(tool_group.endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get(
"properties", {}
).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
Tool(
identifier=tool.name,
description=tool.description,
tool_group=tool_group.name,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
},
provider_resource_id=tool.name,
)
)
return tools
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_id)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_id} does not have metadata")
endpoint = tool.metadata.get("endpoint")
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, args)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
error_code=1 if result.isError else 0,
)