diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 92a7d788e..5e87a72e0 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -64,14 +64,22 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime authorization=authorization, ) - async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]: + async def get_headers_from_request( + self, mcp_endpoint_uri: str + ) -> tuple[dict[str, str], str | None]: """ Extract headers and authorization from request provider data. + For security, Authorization should not be passed via mcp_headers. + Instead, use a dedicated authorization field in the provider data. + Returns: Tuple of (headers_dict, authorization_token) - headers_dict: All headers except Authorization - authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None + + Raises: + ValueError: If Authorization header is found in mcp_headers (security risk) """ def canonicalize_uri(uri: str) -> str: @@ -85,12 +93,20 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime for uri, values in provider_data.mcp_headers.items(): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): continue - # Extract Authorization header separately for security - for key, value in values.items(): + + # Security check: reject Authorization header in mcp_headers + # This prevents accidentally passing inference tokens to MCP servers + for key in values.keys(): if key.lower() == "authorization": - # Remove "Bearer " prefix if present - authorization = value.removeprefix("Bearer ").strip() - else: - headers[key] = value + raise ValueError( + "Authorization header cannot be passed via 'mcp_headers'. " + "Please use a dedicated authorization field in provider_data instead." + ) + + # Collect all headers (Authorization already rejected above) + headers.update(values) + + # TODO: Extract authorization from a dedicated field in provider_data + # For now, authorization remains None until the API is updated return headers, authorization