fix: MCP authorization parameter implementation (#4052)

# What does this PR do?
Adds a user-facing `authorization` parameter to MCP tool definitions
that allows users to explicitly configure credentials per MCP server,
addressing GitHub issue #4034 in a secure manner.


## Test Plan
tests/integration/responses/test_mcp_authentication.py

---------

Co-authored-by: Omar Abdelwahab <omara@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Omar Abdelwahab 2025-11-14 08:54:42 -08:00 committed by GitHub
parent dc49ad3f89
commit eb545034ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 5205 additions and 62 deletions

View file

@ -49,7 +49,10 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
return provider_data.bing_search_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -70,7 +73,9 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
"Ocp-Apim-Subscription-Key": api_key,

View file

@ -48,7 +48,10 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
return provider_data.brave_search_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -70,7 +73,9 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
headers = {

View file

@ -10,8 +10,14 @@ from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => dict of headers to send
mcp_headers: dict[str, dict[str, str]] | None = None
"""
Validator for MCP provider-specific data passed via request headers.
Phase 1: Support old header-based authentication for backward compatibility.
In Phase 2, this will be deprecated in favor of the authorization parameter.
"""
mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
class MCPProviderConfig(BaseModel):

View file

@ -39,15 +39,29 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
return
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(mcp_endpoint.uri, headers)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
)
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
@ -55,19 +69,57 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers_from_request(endpoint)
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
kwargs=kwargs,
headers=provider_headers,
authorization=final_authorization,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
"""
Extract headers and authorization from request provider data (Phase 1 backward compatibility).
Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility.
Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead.
Returns:
Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
"""
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and provider_data.mcp_headers:
if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
headers.update(values)
return headers
# Phase 1: Extract Authorization from mcp_headers for backward compatibility
# (Phase 2 will reject this and require the dedicated authorization parameter)
for key in values.keys():
if key.lower() == "authorization":
# Extract authorization token and strip "Bearer " prefix if present
auth_value = values[key]
if auth_value.startswith("Bearer "):
authorization = auth_value[7:] # Remove "Bearer " prefix
else:
authorization = auth_value
else:
headers[key] = values[key]
return headers, authorization

View file

@ -48,7 +48,10 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
return provider_data.tavily_search_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -69,7 +72,9 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
async with httpx.AsyncClient() as client:
response = await client.post(

View file

@ -49,7 +49,10 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
return provider_data.wolfram_alpha_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -70,7 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
"input": kwargs["query"],