forked from phoenix-oss/llama-stack-mirror
feat: enable MCP execution in Responses impl (#2240)
## Test Plan ``` pytest -s -v 'tests/verifications/openai_api/test_responses.py' \ --provider=stack:together --model meta-llama/Llama-4-Scout-17B-16E-Instruct ```
This commit is contained in:
parent
66f09f24ed
commit
3faf1e4a79
15 changed files with 865 additions and 382 deletions
|
@ -4,53 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import exceptiongroup
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp import types as mcp_types
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
||||
|
||||
from .config import MCPProviderConfig
|
||||
|
||||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
||||
try:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
except BaseException as e:
|
||||
if isinstance(e, exceptiongroup.BaseExceptionGroup):
|
||||
for exc in e.exceptions:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(e) from e
|
||||
|
||||
raise
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||
self.config = config
|
||||
|
@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
|
|||
# this endpoint should be retrieved by getting the tool group right?
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
|
||||
headers = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||
tools = []
|
||||
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": mcp_endpoint.uri,
|
||||
},
|
||||
)
|
||||
)
|
||||
return ListToolDefsResponse(data=tools)
|
||||
return await list_mcp_tools(mcp_endpoint.uri, headers)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
|
@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
|
|||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
headers = await self.get_headers_from_request(endpoint)
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
result = await session.call_tool(tool.identifier, kwargs)
|
||||
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, mcp_types.TextContent):
|
||||
content.append(TextContentItem(text=item.text))
|
||||
elif isinstance(item, mcp_types.ImageContent):
|
||||
content.append(ImageContentItem(image=item.data))
|
||||
elif isinstance(item, mcp_types.EmbeddedResource):
|
||||
logger.warning(f"EmbeddedResource is not supported: {item}")
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {type(item)}")
|
||||
return ToolInvocationResult(
|
||||
content=content,
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
||||
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||
|
||||
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
|
||||
def canonicalize_uri(uri: str) -> str:
|
||||
|
@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
|
|||
for uri, values in provider_data.mcp_headers.items():
|
||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||
continue
|
||||
for entry in values:
|
||||
parts = entry.split(":")
|
||||
if len(parts) == 2:
|
||||
k, v = parts
|
||||
headers[k.strip()] = v.strip()
|
||||
headers.update(convert_header_list_to_dict(values))
|
||||
return headers
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue