From 445135b8cc636582bf98fcf0249d5df743beb37b Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Fri, 7 Nov 2025 11:45:47 -0800 Subject: [PATCH] feat: implement dedicated mcp_authorization field for remote provider Completes the TODO for extracting authorization from a dedicated field. What changed: - Added mcp_authorization field to MCPProviderDataValidator - Updated get_headers_from_request() to extract from mcp_authorization - Authorization is now properly isolated per MCP endpoint API usage example: { "provider_data": { "mcp_headers": { "http://mcp-server.com": { "X-Trace-ID": "trace-123" } }, "mcp_authorization": { "http://mcp-server.com": "mcp_token_xyz789" } } } Security guarantees: - Authorization cannot be in mcp_headers (validation rejects it) - Each MCP endpoint gets its own dedicated token - No cross-service token leakage possible --- .../model_context_protocol/config.py | 2 + .../model_context_protocol.py | 51 ++++++++++++------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index b8c5e77fd..73f891c20 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -12,6 +12,8 @@ from pydantic import BaseModel class MCPProviderDataValidator(BaseModel): # mcp_endpoint => dict of headers to send mcp_headers: dict[str, dict[str, str]] | None = None + # mcp_endpoint => authorization token (without "Bearer " prefix) + mcp_authorization: dict[str, str] | None = None class MCPProviderConfig(BaseModel): diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 5e87a72e0..6ddb23631 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -25,7 +25,9 @@ from .config import MCPProviderConfig logger = get_logger(__name__, category="tools") -class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class ModelContextProtocolToolRuntimeImpl( + ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config @@ -45,9 +47,13 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri) - return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization) + return await list_mcp_tools( + endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization + ) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, tool_name: str, kwargs: dict[str, Any] + ) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) if tool.metadata is None or tool.metadata.get("endpoint") is None: raise ValueError(f"Tool {tool_name} does not have metadata") @@ -89,24 +95,31 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime authorization = None provider_data = self.get_request_provider_data() - if provider_data and provider_data.mcp_headers: - for uri, values in provider_data.mcp_headers.items(): - if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): - continue + if provider_data: + # Extract headers (excluding Authorization) + if provider_data.mcp_headers: + for uri, values in provider_data.mcp_headers.items(): + if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): + continue - # Security check: reject Authorization header in mcp_headers - # This prevents accidentally passing inference tokens to MCP servers - for key in values.keys(): - if key.lower() == "authorization": - raise ValueError( - "Authorization header cannot be passed via 'mcp_headers'. " - "Please use a dedicated authorization field in provider_data instead." - ) + # Security check: reject Authorization header in mcp_headers + # This prevents accidentally passing inference tokens to MCP servers + for key in values.keys(): + if key.lower() == "authorization": + raise ValueError( + "Authorization header cannot be passed via 'mcp_headers'. " + "Please use 'mcp_authorization' in provider_data instead." + ) - # Collect all headers (Authorization already rejected above) - headers.update(values) + # Collect all headers (Authorization already rejected above) + headers.update(values) - # TODO: Extract authorization from a dedicated field in provider_data - # For now, authorization remains None until the API is updated + # Extract authorization from dedicated field + if provider_data.mcp_authorization: + canonical_endpoint = canonicalize_uri(mcp_endpoint_uri) + for uri, token in provider_data.mcp_authorization.items(): + if canonicalize_uri(uri) == canonical_endpoint: + authorization = token + break return headers, authorization