feat: implement dedicated mcp_authorization field for remote provider

Completes the TODO for extracting authorization from a dedicated field.

What changed:
- Added mcp_authorization field to MCPProviderDataValidator
- Updated get_headers_from_request() to extract from mcp_authorization
- Authorization is now properly isolated per MCP endpoint

API usage example:
{
  "provider_data": {
    "mcp_headers": {
      "http://mcp-server.com": {
        "X-Trace-ID": "trace-123"
      }
    },
    "mcp_authorization": {
      "http://mcp-server.com": "mcp_token_xyz789"
    }
  }
}

Security guarantees:
- Authorization cannot be in mcp_headers (validation rejects it)
- Each MCP endpoint gets its own dedicated token
- No cross-service token leakage possible
This commit is contained in:
Omar Abdelwahab 2025-11-07 11:45:47 -08:00
parent a842c90059
commit 445135b8cc
2 changed files with 34 additions and 19 deletions

View file

@ -12,6 +12,8 @@ from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => dict of headers to send
mcp_headers: dict[str, dict[str, str]] | None = None
# mcp_endpoint => authorization token (without "Bearer " prefix)
mcp_authorization: dict[str, str] | None = None
class MCPProviderConfig(BaseModel):

View file

@ -25,7 +25,9 @@ from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class ModelContextProtocolToolRuntimeImpl(
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@ -45,9 +47,13 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
@ -89,24 +95,31 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
if provider_data:
# Extract headers (excluding Authorization)
if provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
# Security check: reject Authorization header in mcp_headers
# This prevents accidentally passing inference tokens to MCP servers
for key in values.keys():
if key.lower() == "authorization":
raise ValueError(
"Authorization header cannot be passed via 'mcp_headers'. "
"Please use a dedicated authorization field in provider_data instead."
)
# Security check: reject Authorization header in mcp_headers
# This prevents accidentally passing inference tokens to MCP servers
for key in values.keys():
if key.lower() == "authorization":
raise ValueError(
"Authorization header cannot be passed via 'mcp_headers'. "
"Please use 'mcp_authorization' in provider_data instead."
)
# Collect all headers (Authorization already rejected above)
headers.update(values)
# Collect all headers (Authorization already rejected above)
headers.update(values)
# TODO: Extract authorization from a dedicated field in provider_data
# For now, authorization remains None until the API is updated
# Extract authorization from dedicated field
if provider_data.mcp_authorization:
canonical_endpoint = canonicalize_uri(mcp_endpoint_uri)
for uri, token in provider_data.mcp_authorization.items():
if canonicalize_uri(uri) == canonical_endpoint:
authorization = token
break
return headers, authorization