mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
add model context protocol provider
This commit is contained in:
parent
dc21e14f64
commit
33af4f919f
8 changed files with 160 additions and 10 deletions
|
@ -26,11 +26,11 @@ class ToolParameter(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Tool(Resource):
|
class Tool(Resource):
|
||||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||||
name: str
|
|
||||||
tool_group: str
|
tool_group: str
|
||||||
description: str
|
description: str
|
||||||
parameters: List[ToolParameter]
|
parameters: List[ToolParameter]
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
)
|
)
|
||||||
|
@ -55,12 +55,14 @@ class MCPToolGroup(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["model_context_protocol"] = "model_context_protocol"
|
type: Literal["model_context_protocol"] = "model_context_protocol"
|
||||||
|
name: str
|
||||||
endpoint: URL
|
endpoint: URL
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class UserDefinedToolGroup(BaseModel):
|
class UserDefinedToolGroup(BaseModel):
|
||||||
type: Literal["user_defined"] = "user_defined"
|
type: Literal["user_defined"] = "user_defined"
|
||||||
|
name: str
|
||||||
tools: List[ToolDef]
|
tools: List[ToolDef]
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,7 +89,6 @@ class Tools(Protocol):
|
||||||
@webmethod(route="/toolgroups/register", method="POST")
|
@webmethod(route="/toolgroups/register", method="POST")
|
||||||
async def register_tool_group(
|
async def register_tool_group(
|
||||||
self,
|
self,
|
||||||
name: str,
|
|
||||||
tool_group: ToolGroup,
|
tool_group: ToolGroup,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -115,6 +116,9 @@ class Tools(Protocol):
|
||||||
class ToolRuntime(Protocol):
|
class ToolRuntime(Protocol):
|
||||||
tool_store: ToolStore
|
tool_store: ToolStore
|
||||||
|
|
||||||
|
@webmethod(route="/tool-runtime/discover", method="POST")
|
||||||
|
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ...
|
||||||
|
|
||||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||||
async def invoke_tool(
|
async def invoke_tool(
|
||||||
self, tool_id: str, args: Dict[str, Any]
|
self, tool_id: str, args: Dict[str, Any]
|
||||||
|
|
|
@ -393,3 +393,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
tool_id=tool_id,
|
tool_id=tool_id,
|
||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||||
|
return await self.routing_table.get_provider_impl(
|
||||||
|
tool_group.name
|
||||||
|
).discover_tools(tool_group)
|
||||||
|
|
|
@ -480,34 +480,40 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools):
|
||||||
|
|
||||||
async def register_tool_group(
|
async def register_tool_group(
|
||||||
self,
|
self,
|
||||||
name: str,
|
|
||||||
tool_group: ToolGroup,
|
tool_group: ToolGroup,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
tools = []
|
||||||
if isinstance(tool_group, MCPToolGroup):
|
if isinstance(tool_group, MCPToolGroup):
|
||||||
# TODO: first needs to be resolved to corresponding tools available in the MCP server
|
# TODO: Actually find the right MCP provider
|
||||||
raise NotImplementedError("MCP tool provider not implemented yet")
|
if provider_id is None:
|
||||||
|
raise ValueError("MCP provider_id not specified")
|
||||||
|
tools = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||||
|
tool_group
|
||||||
|
)
|
||||||
|
for tool in tools:
|
||||||
|
tool.provider_id = provider_id
|
||||||
elif isinstance(tool_group, UserDefinedToolGroup):
|
elif isinstance(tool_group, UserDefinedToolGroup):
|
||||||
for tool in tool_group.tools:
|
for tool in tool_group.tools:
|
||||||
|
|
||||||
tools.append(
|
tools.append(
|
||||||
Tool(
|
Tool(
|
||||||
identifier=tool.name,
|
identifier=tool.name,
|
||||||
tool_group=name,
|
tool_group=tool_group.name,
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description,
|
description=tool.description,
|
||||||
parameters=tool.parameters,
|
parameters=tool.parameters,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
tool_prompt_format=tool.tool_prompt_format,
|
tool_prompt_format=tool.tool_prompt_format,
|
||||||
provider_resource_id=tool.name,
|
provider_resource_id=tool.name,
|
||||||
|
metadata=tool.metadata,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
raise ValueError(f"Unknown tool group: {tool_group}")
|
||||||
|
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
existing_tool = await self.get_tool(tool.name)
|
existing_tool = await self.get_tool(tool.identifier)
|
||||||
# Compare existing and new object if one exists
|
# Compare existing and new object if one exists
|
||||||
if existing_tool:
|
if existing_tool:
|
||||||
# Compare all fields except provider_id since that might be None in new obj
|
# Compare all fields except provider_id since that might be None in new obj
|
||||||
|
|
|
@ -4,11 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
|
from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
@ -42,6 +42,9 @@ class BraveSearchToolRuntimeImpl(
|
||||||
)
|
)
|
||||||
return provider_data.api_key
|
return provider_data.api_key
|
||||||
|
|
||||||
|
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||||
|
raise NotImplementedError("Brave search tool group not supported")
|
||||||
|
|
||||||
async def invoke_tool(
|
async def invoke_tool(
|
||||||
self, tool_id: str, args: Dict[str, Any]
|
self, tool_id: str, args: Dict[str, Any]
|
||||||
) -> ToolInvocationResult:
|
) -> ToolInvocationResult:
|
||||||
|
|
|
@ -6,7 +6,13 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AdapterSpec,
|
||||||
|
Api,
|
||||||
|
InlineProviderSpec,
|
||||||
|
ProviderSpec,
|
||||||
|
remote_provider_spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def available_providers() -> List[ProviderSpec]:
|
def available_providers() -> List[ProviderSpec]:
|
||||||
|
@ -19,4 +25,13 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||||
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.tool_runtime,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="model-context-protocol",
|
||||||
|
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||||
|
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
||||||
|
pip_packages=["mcp"],
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .config import ModelContextProtocolConfig
|
||||||
|
|
||||||
|
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
||||||
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
||||||
|
impl = ModelContextProtocolToolRuntimeImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,11 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContextProtocolConfig(BaseModel):
|
||||||
|
pass
|
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
|
MCPToolGroup,
|
||||||
|
Tool,
|
||||||
|
ToolGroup,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolParameter,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
|
from .config import ModelContextProtocolConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
|
def __init__(self, config: ModelContextProtocolConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||||
|
if not isinstance(tool_group, MCPToolGroup):
|
||||||
|
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
async with sse_client(tool_group.endpoint.uri) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
await session.initialize()
|
||||||
|
tools_result = await session.list_tools()
|
||||||
|
for tool in tools_result.tools:
|
||||||
|
parameters = []
|
||||||
|
for param_name, param_schema in tool.inputSchema.get(
|
||||||
|
"properties", {}
|
||||||
|
).items():
|
||||||
|
parameters.append(
|
||||||
|
ToolParameter(
|
||||||
|
name=param_name,
|
||||||
|
parameter_type=param_schema.get("type", "string"),
|
||||||
|
description=param_schema.get("description", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tools.append(
|
||||||
|
Tool(
|
||||||
|
identifier=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
tool_group=tool_group.name,
|
||||||
|
parameters=parameters,
|
||||||
|
metadata={
|
||||||
|
"endpoint": tool_group.endpoint.uri,
|
||||||
|
},
|
||||||
|
provider_resource_id=tool.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
async def invoke_tool(
|
||||||
|
self, tool_id: str, args: Dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
tool = await self.tool_store.get_tool(tool_id)
|
||||||
|
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||||
|
raise ValueError(f"Tool {tool_id} does not have metadata")
|
||||||
|
endpoint = tool.metadata.get("endpoint")
|
||||||
|
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||||
|
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||||
|
|
||||||
|
async with sse_client(endpoint) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
await session.initialize()
|
||||||
|
result = await session.call_tool(tool.identifier, args)
|
||||||
|
|
||||||
|
return ToolInvocationResult(
|
||||||
|
content="\n".join([result.model_dump_json() for result in result.content]),
|
||||||
|
error_code=1 if result.isError else 0,
|
||||||
|
)
|
Loading…
Add table
Add a link
Reference in a new issue