feat: accept MCP authorization headers for MCP toolgroups (#2230)

The most interesting MCP servers are those with an authorization wall in
front of them. This PR uses the existing `provider_data` mechanism of
passing provider API keys for passing MCP access tokens (in fact,
arbitrary headers in the style of the OpenAI Responses API) from the
client through to the MCP server.

```
class MCPProviderDataValidator(BaseModel):
    # mcp_endpoint => list of headers to send
    mcp_headers: dict[str, list[str]] | None = None
```

Note how we must stuff the headers for all MCP endpoints into a single
"MCPProviderDataValidator". Unlike existing providers (e.g., Together
and Fireworks for inference) where we could name the provider api keys
clearly (`together_api_key`, `fireworks_api_key`), we cannot name these
keys for MCP. We have a single generic MCP provider which can serve
multiple "toolgroups". So we use a dict to combine all the headers for
all MCP endpoints you may want to use in an agentic call.


## Test Plan

See the added integration test for usage.
This commit is contained in:
Ashwin Bharambe 2025-05-23 08:52:18 -07:00 committed by GitHub
parent 2708312168
commit 51945f1e57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 331 additions and 47 deletions

View file

@ -76,8 +76,8 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
async def get_tool(self, tool_name: str) -> Tool: ...
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel):

View file

@ -236,6 +236,10 @@ class AuthenticationConfig(BaseModel):
)
class AuthenticationRequiredError(Exception):
pass
class QuotaPeriod(str, Enum):
DAY = "day"

View file

@ -261,9 +261,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized")
# Create headers with provider data if available
headers = {}
headers = options.headers or {}
if self.provider_data:
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
if all(key not in headers for key in keys):
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
# Use context manager for provider data
with request_provider_data_context(headers):

View file

@ -28,7 +28,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
@ -122,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
else:
return HTTPException(
status_code=500,

View file

@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec(
adapter_type="model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
),
),
]

View file

@ -4,18 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .config import MCPProviderConfig
class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
async def get_adapter_impl(config: MCPProviderConfig, _deps):
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
impl = ModelContextProtocolToolRuntimeImpl(config)
impl = ModelContextProtocolToolRuntimeImpl(config, _deps)
await impl.initialize()
return impl

View file

@ -9,7 +9,12 @@ from typing import Any
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => list of headers to send
mcp_headers: dict[str, list[str]] | None = None
class MCPProviderConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -4,13 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
@ -18,13 +23,36 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig):
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
async def initialize(self):
@ -33,34 +61,34 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
tools = []
async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
)
)
return ListToolDefsResponse(data=tools)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
@ -71,12 +99,39 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, kwargs)
headers = await self.get_headers_from_request(endpoint)
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
content=content,
error_code=1 if result.isError else 0,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
provider_data = self.get_request_provider_data()
if provider_data and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
for entry in values:
parts = entry.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers