mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 09:04:30 +00:00
add inline mcp provider
This commit is contained in:
parent
ffc6bd4805
commit
2c265d803c
16 changed files with 398 additions and 49 deletions
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
|
|||
import requests
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
MCPConfig,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
|
@ -50,7 +50,9 @@ class BraveSearchToolRuntimeImpl(
|
|||
return provider_data.api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
self,
|
||||
tool_group_id: Optional[str] = None,
|
||||
mcp_config: Optional[MCPConfig] = None,
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
MCPConfig,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
|
|
@ -30,13 +31,15 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
pass
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
self,
|
||||
tool_group_id: Optional[str] = None,
|
||||
mcp_config: Optional[MCPConfig] = None,
|
||||
) -> List[ToolDef]:
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
if mcp_config is None:
|
||||
raise ValueError("mcp_config is required")
|
||||
|
||||
tools = []
|
||||
async with sse_client(mcp_endpoint.uri) as streams:
|
||||
async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
|
|
@ -58,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": mcp_endpoint.uri,
|
||||
"mcp_config": mcp_config,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -68,13 +71,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
if tool.metadata is None or tool.metadata.get("mcp_config") is None:
|
||||
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||
endpoint = tool.metadata.get("endpoint")
|
||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
|
||||
mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
|
||||
|
||||
async with sse_client(endpoint) as streams:
|
||||
async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(tool.identifier, args)
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
MCPConfig,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
|
@ -50,7 +50,9 @@ class TavilySearchToolRuntimeImpl(
|
|||
return provider_data.api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
self,
|
||||
tool_group_id: Optional[str] = None,
|
||||
mcp_config: Optional[MCPConfig] = None,
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
MCPConfig,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
|
@ -51,7 +51,9 @@ class WolframAlphaToolRuntimeImpl(
|
|||
return provider_data.api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
self,
|
||||
tool_group_id: Optional[str] = None,
|
||||
mcp_config: Optional[MCPConfig] = None,
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue