fix: MCP authorization parameter implementation (#4052)

# What does this PR do?
Adding a user-facing `authorization ` parameter to MCP tool definitions
that allows users to explicitly configure credentials per MCP server,
addressing GitHub Issue #4034 in a secure manner.


## Test Plan
tests/integration/responses/test_mcp_authentication.py

---------

Co-authored-by: Omar Abdelwahab <omara@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Omar Abdelwahab 2025-11-14 08:54:42 -08:00 committed by GitHub
parent dc49ad3f89
commit eb545034ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 5205 additions and 62 deletions

View file

@ -34,16 +34,16 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
authorization=authorization,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)

View file

@ -49,7 +49,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
async def list_tools(
self, toolgroup_id: str | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
@ -61,7 +63,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
try:
await self._index_tools(toolgroup)
await self._index_tools(toolgroup, authorization=authorization)
except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers.
@ -76,9 +78,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolDefsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
tooldefs_response = await provider_impl.list_runtime_tools(
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
)
tooldefs = tooldefs_response.data
for t in tooldefs:

View file

@ -257,6 +257,19 @@ class OpenAIResponsesImpl:
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Validate MCP tools: ensure Authorization header is not passed via headers dict
if tools:
from llama_stack_api.openai_responses import OpenAIResponseInputToolMCP
for tool in tools:
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.headers:
for key in tool.headers.keys():
if key.lower() == "authorization":
raise ValueError(
"Authorization header cannot be passed via 'headers'. "
"Please use the 'authorization' parameter instead."
)
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
if conversation is not None:

View file

@ -1091,10 +1091,12 @@ class StreamingResponseOrchestrator:
"server_url": mcp_tool.server_url,
"mcp_list_tools_id": list_id,
}
# List MCP tools with authorization from tool config
async with tracing.span("list_mcp_tools", attributes):
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
headers=mcp_tool.headers,
authorization=mcp_tool.authorization,
)
# Create the MCP list tools message

View file

@ -296,12 +296,14 @@ class ToolExecutor:
"server_url": mcp_tool.server_url,
"tool_name": function_name,
}
# Invoke MCP tool with authorization from tool config
async with tracing.span("invoke_mcp_tool", attributes):
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
headers=mcp_tool.headers,
authorization=mcp_tool.authorization,
)
elif function_name == "knowledge_search":
response_file_search_tool = (

View file

@ -276,7 +276,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
@ -304,7 +307,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
vector_store_ids = kwargs.get("vector_store_ids", [])
query_config = kwargs.get("query_config")
if query_config:

View file

@ -49,7 +49,10 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
return provider_data.bing_search_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -70,7 +73,9 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
"Ocp-Apim-Subscription-Key": api_key,

View file

@ -48,7 +48,10 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
return provider_data.brave_search_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -70,7 +73,9 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
headers = {

View file

@ -10,8 +10,14 @@ from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => dict of headers to send
mcp_headers: dict[str, dict[str, str]] | None = None
"""
Validator for MCP provider-specific data passed via request headers.
Phase 1: Support old header-based authentication for backward compatibility.
In Phase 2, this will be deprecated in favor of the authorization parameter.
"""
mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
class MCPProviderConfig(BaseModel):

View file

@ -39,15 +39,29 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
return
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(mcp_endpoint.uri, headers)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
)
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
@ -55,19 +69,57 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers_from_request(endpoint)
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
kwargs=kwargs,
headers=provider_headers,
authorization=final_authorization,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
"""
Extract headers and authorization from request provider data (Phase 1 backward compatibility).
Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility.
Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead.
Returns:
Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
"""
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and provider_data.mcp_headers:
if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
headers.update(values)
return headers
# Phase 1: Extract Authorization from mcp_headers for backward compatibility
# (Phase 2 will reject this and require the dedicated authorization parameter)
for key in values.keys():
if key.lower() == "authorization":
# Extract authorization token and strip "Bearer " prefix if present
auth_value = values[key]
if auth_value.startswith("Bearer "):
authorization = auth_value[7:] # Remove "Bearer " prefix
else:
authorization = auth_value
else:
headers[key] = values[key]
return headers, authorization

View file

@ -48,7 +48,10 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
return provider_data.tavily_search_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -69,7 +72,9 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
async with httpx.AsyncClient() as client:
response = await client.post(

View file

@ -49,7 +49,10 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
return provider_data.wolfram_alpha_api_key
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -70,7 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
"input": kwargs["query"],

View file

@ -30,6 +30,40 @@ from llama_stack_api import (
logger = get_logger(__name__, category="tools")
def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
"""
Prepare headers for MCP requests with authorization support.
Args:
base_headers: Base headers dictionary (can be None)
authorization: OAuth access token (without "Bearer " prefix)
Returns:
Headers dictionary with Authorization header if token provided
Raises:
ValueError: If Authorization header is specified in the headers dict (security risk)
"""
headers = dict(base_headers or {})
# Security check: reject any Authorization header in the headers dict
# Users must use the authorization parameter instead to avoid security risks
existing_keys_lower = {k.lower() for k in headers.keys()}
if "authorization" in existing_keys_lower:
raise ValueError(
"For security reasons, Authorization header cannot be passed via 'headers'. "
"Please use the 'authorization' parameter instead."
)
# Add Authorization header if token provided
if authorization:
# OAuth access token - add "Bearer " prefix
headers["Authorization"] = f"Bearer {authorization}"
return headers
protocol_cache = TTLDict(ttl_seconds=3600)
@ -112,9 +146,29 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
raise
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
async def list_mcp_tools(
endpoint: str,
headers: dict[str, str] | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
"""List tools available from an MCP server.
Args:
endpoint: MCP server endpoint URL
headers: Optional base headers to include
authorization: Optional OAuth access token (just the token, not "Bearer <token>")
Returns:
List of tool definitions from the MCP server
Raises:
ValueError: If Authorization is found in the headers parameter
"""
# Prepare headers with authorization handling
final_headers = prepare_mcp_headers(headers, authorization)
tools = []
async with client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, final_headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
tools.append(
@ -132,9 +186,31 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
endpoint: str,
tool_name: str,
kwargs: dict[str, Any],
headers: dict[str, str] | None = None,
authorization: str | None = None,
) -> ToolInvocationResult:
async with client_wrapper(endpoint, headers) as session:
"""Invoke an MCP tool with the given arguments.
Args:
endpoint: MCP server endpoint URL
tool_name: Name of the tool to invoke
kwargs: Tool invocation arguments
headers: Optional base headers to include
authorization: Optional OAuth access token (just the token, not "Bearer <token>")
Returns:
Tool invocation result with content and error information
Raises:
ValueError: If Authorization header is found in the headers parameter
"""
# Prepare headers with authorization handling
final_headers = prepare_mcp_headers(headers, authorization)
async with client_wrapper(endpoint, final_headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []

View file

@ -609,14 +609,14 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
async def _patched_tool_invoke_method(
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
):
"""Patched version of tool runtime invoke_tool method for recording/replay."""
global _current_mode, _current_storage
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
# Normal operation
return await original_method(self, tool_name, kwargs)
return await original_method(self, tool_name, kwargs, authorization=authorization)
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
@ -634,7 +634,7 @@ async def _patched_tool_invoke_method(
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Make the tool call and record it
result = await original_method(self, tool_name, kwargs)
result = await original_method(self, tool_name, kwargs, authorization=authorization)
request_data = {
"test_id": get_test_context(),
@ -885,9 +885,11 @@ def patch_inference_clients():
OllamaAsyncClient.list = patched_ollama_list
# Create patched methods for tool runtimes
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
async def patched_tavily_invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
):
return await _patched_tool_invoke_method(
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs, authorization=authorization
)
# Apply tool runtime patches