add inline mcp provider

This commit is contained in:
Dinesh Yeduguru 2025-01-08 23:11:43 -08:00
parent ffc6bd4805
commit 2c265d803c
16 changed files with 398 additions and 49 deletions

View file

@ -5,10 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -57,18 +57,35 @@ class ToolDef(BaseModel):
)
@json_schema_type
class MCPInlineConfig(BaseModel):
    """Launch an MCP server as a local subprocess and talk to it over stdio.

    `command` is the executable to spawn; `args` and `env` are passed through
    to the subprocess launcher when provided.
    """

    type: Literal["inline"] = "inline"
    command: str
    args: Optional[List[str]] = None
    env: Optional[Dict[str, Any]] = None
@json_schema_type
class MCPRemoteConfig(BaseModel):
    """Connect to an already-running MCP server at a network endpoint."""

    type: Literal["remote"] = "remote"
    mcp_endpoint: URL


# Union of the two MCP deployment styles; variants are distinguished by
# their `type` literal field.
MCPConfig = register_schema(Union[MCPInlineConfig, MCPRemoteConfig], name="MCPConfig")
@json_schema_type
class ToolGroupInput(BaseModel):
    """Declarative registration entry for a tool group in a stack configuration."""

    toolgroup_id: str
    provider_id: str
    args: Optional[Dict[str, Any]] = None
    # Only set for MCP-backed tool groups (inline or remote variant).
    mcp_config: Optional[MCPConfig] = None
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
mcp_config: Optional[MCPConfig] = None
args: Optional[Dict[str, Any]] = None
@ -92,7 +109,7 @@ class ToolGroups(Protocol):
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
mcp_config: Optional[MCPConfig] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""Register a tool group"""
@ -131,7 +148,9 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")

View file

@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.apis.tools import MCPConfig, ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable
@ -418,8 +418,10 @@ class ToolRuntimeRouter(ToolRuntime):
)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
tool_group_id, mcp_config
)

View file

@ -26,7 +26,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.apis.tools import MCPConfig, Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@ -504,15 +504,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
mcp_config: Optional[MCPConfig] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint
toolgroup_id, mcp_config
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
ToolHost.model_context_protocol if mcp_config else ToolHost.distribution
)
for tool_def in tool_defs:
@ -547,7 +547,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
mcp_config=mcp_config,
args=args,
)
)

View file

@ -9,8 +9,8 @@ import logging
import tempfile
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@ -43,7 +43,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(

View file

@ -10,11 +10,11 @@ import secrets
import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
MCPConfig,
ToolDef,
ToolInvocationResult,
ToolParameter,
@ -52,7 +52,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
class ModelContextProtocolToolProviderDataValidator(BaseModel):
    """Schema for per-request provider data; requires an `api_key` string.

    NOTE(review): an inline (subprocess) MCP server may not need an API key —
    confirm this requirement is intentional for the inline provider.
    """

    api_key: str
async def get_provider_impl(config: ModelContextProtocolConfig, _deps):
    """Construct and initialize the inline MCP tool-runtime implementation.

    `_deps` is accepted for provider-loader interface compatibility but unused.
    """
    provider = ModelContextProtocolToolRuntimeImpl(config)
    await provider.initialize()
    return provider

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
    """Provider-level config for the inline MCP tool runtime.

    Intentionally empty: server launch details (command, args, env) arrive
    per tool group via `mcp_config`, not at provider construction time.
    """

View file

@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from mcp import ClientSession
from mcp.client.stdio import stdio_client, StdioServerParameters
from pydantic import TypeAdapter
from llama_stack.apis.tools import (
MCPConfig,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
    """Inline MCP tool runtime.

    Spawns the configured MCP server as a local subprocess (stdio transport)
    to enumerate its tools and to invoke them. The inline launch parameters
    are persisted in each tool's metadata so `invoke_tool` can respawn the
    same server later.
    """

    def __init__(self, config: ModelContextProtocolConfig):
        self.config = config

    async def initialize(self):
        pass

    @staticmethod
    def _server_params(mcp_config) -> StdioServerParameters:
        """Build stdio launch parameters from an inline MCP config.

        Raises:
            ValueError: if `mcp_config` is not the inline variant. MCPConfig
                is a union, and only the inline variant carries
                command/args/env — failing early here gives a clear error
                instead of an AttributeError on a remote config.
        """
        if getattr(mcp_config, "type", None) != "inline":
            raise ValueError(
                "inline MCP provider requires an inline mcp_config (type='inline')"
            )
        return StdioServerParameters(
            command=mcp_config.command,
            args=mcp_config.args,
            env=mcp_config.env,
        )

    async def list_runtime_tools(
        self,
        tool_group_id: Optional[str] = None,
        mcp_config: Optional[MCPConfig] = None,
    ) -> List[ToolDef]:
        """List the tools exposed by the configured MCP server.

        Starts the server subprocess, asks it for its tool list, and converts
        each entry into a ToolDef. The serialized config is stashed in tool
        metadata for later invocation.

        Raises:
            ValueError: if `mcp_config` is missing or not the inline variant.
        """
        if mcp_config is None:
            raise ValueError("mcp_config is required")

        tools = []
        async with stdio_client(self._server_params(mcp_config)) as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()
                tools_result = await session.list_tools()
                for tool in tools_result.tools:
                    parameters = []
                    # inputSchema is a JSON Schema object; each property
                    # becomes one ToolParameter, defaulting to "string".
                    for param_name, param_schema in tool.inputSchema.get(
                        "properties", {}
                    ).items():
                        parameters.append(
                            ToolParameter(
                                name=param_name,
                                parameter_type=param_schema.get("type", "string"),
                                description=param_schema.get("description", ""),
                            )
                        )
                    tools.append(
                        ToolDef(
                            name=tool.name,
                            description=tool.description,
                            parameters=parameters,
                            metadata={
                                # JSON string so it round-trips through the
                                # tool store; parsed back in invoke_tool.
                                "mcp_config": mcp_config.model_dump_json(),
                            },
                        )
                    )
        return tools

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        """Invoke a previously-listed tool by respawning its MCP server.

        Raises:
            ValueError: if the stored tool has no `mcp_config` metadata, or
                the stored config is not the inline variant.
        """
        tool = await self.tool_store.get_tool(tool_name)
        if tool.metadata is None or tool.metadata.get("mcp_config") is None:
            raise ValueError(f"Tool {tool_name} does not have metadata")
        # Metadata holds the config as a JSON string (see list_runtime_tools);
        # validate it back into the MCPConfig union.
        mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
        mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)

        async with stdio_client(self._server_params(mcp_config)) as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()
                result = await session.call_tool(tool.identifier, arguments=args)
                return ToolInvocationResult(
                    content="\n".join(
                        item.model_dump_json() for item in result.content
                    ),
                    error_code=1 if result.isError else 0,
                )

View file

@ -32,6 +32,13 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
),
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::model-context-protocol",
pip_packages=["mcp"],
module="llama_stack.providers.inline.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.inline.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@ -50,7 +50,9 @@ class BraveSearchToolRuntimeImpl(
return provider_data.api_key
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from mcp import ClientSession
from mcp.client.sse import sse_client
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPConfig,
ToolDef,
ToolInvocationResult,
ToolParameter,
@ -30,13 +31,15 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
if mcp_config is None:
raise ValueError("mcp_config is required")
tools = []
async with sse_client(mcp_endpoint.uri) as streams:
async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
@ -58,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
"mcp_config": mcp_config,
},
)
)
@ -68,13 +71,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
if tool.metadata is None or tool.metadata.get("mcp_config") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
endpoint = tool.metadata.get("endpoint")
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
async with sse_client(endpoint) as streams:
async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, args)

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@ -50,7 +50,9 @@ class TavilySearchToolRuntimeImpl(
return provider_data.api_key
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@ -51,7 +51,9 @@ class WolframAlphaToolRuntimeImpl(
return provider_data.api_key
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(