mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-20 18:42:26 +00:00
add model context protocol provider
This commit is contained in:
parent
dc21e14f64
commit
33af4f919f
8 changed files with 160 additions and 10 deletions
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
|
|
@ -42,6 +42,9 @@ class BraveSearchToolRuntimeImpl(
|
|||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||
raise NotImplementedError("Brave search tool group not supported")
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,13 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
|
@ -19,4 +25,13 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
||||
pip_packages=["mcp"],
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
|
||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||
|
||||
|
||||
class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
||||
impl = ModelContextProtocolToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelContextProtocolConfig(BaseModel):
|
||||
pass
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from llama_stack.apis.tools import (
|
||||
MCPToolGroup,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(self, config: ModelContextProtocolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||
if not isinstance(tool_group, MCPToolGroup):
|
||||
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
|
||||
|
||||
tools = []
|
||||
async with sse_client(tool_group.endpoint.uri) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get(
|
||||
"properties", {}
|
||||
).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool.name,
|
||||
description=tool.description,
|
||||
tool_group=tool_group.name,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": tool_group.endpoint.uri,
|
||||
},
|
||||
provider_resource_id=tool.name,
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_id)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
raise ValueError(f"Tool {tool_id} does not have metadata")
|
||||
endpoint = tool.metadata.get("endpoint")
|
||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
async with sse_client(endpoint) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(tool.identifier, args)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content="\n".join([result.model_dump_json() for result in result.content]),
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue