feat: accept MCP authorization headers for MCP toolgroups

This commit is contained in:
Ashwin Bharambe 2025-05-22 14:39:59 -07:00
parent 2708312168
commit 0d67e17a91
7 changed files with 294 additions and 44 deletions

View file

@ -236,6 +236,10 @@ class AuthenticationConfig(BaseModel):
)
class AuthenticationRequiredError(Exception):
pass
class QuotaPeriod(str, Enum):
DAY = "day"

View file

@ -261,9 +261,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized")
# Create headers with provider data if available
headers = {}
headers = options.headers or {}
if self.provider_data:
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
if all(key not in headers for key in keys):
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
# Use context manager for provider data
with request_provider_data_context(headers):

View file

@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec(
adapter_type="model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
),
),
]

View file

@ -4,18 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .config import MCPProviderConfig
class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
async def get_adapter_impl(config: MCPProviderConfig, _deps):
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
impl = ModelContextProtocolToolRuntimeImpl(config)
impl = ModelContextProtocolToolRuntimeImpl(config, _deps)
await impl.initialize()
return impl

View file

@ -9,7 +9,12 @@ from typing import Any
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => list of headers to send
mcp_headers: dict[str, list[str]] | None = None
class MCPProviderConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -4,13 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
@ -18,13 +23,36 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig):
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
async def initialize(self):
@ -33,34 +61,34 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
tools = []
async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
)
)
return ListToolDefsResponse(data=tools)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
@ -71,12 +99,39 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, kwargs)
headers = await self.get_headers_from_request(endpoint)
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
content=content,
error_code=1 if result.isError else 0,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
provider_data = self.get_request_provider_data()
if provider_data and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
for entry in values:
parts = entry.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers