feat: unify MCP authentication across Responses and Tool Runtime APIs

- Add authorization parameter to Tool Runtime API signatures (list_runtime_tools, invoke_tool)
- Update MCP provider implementation to use authorization from request body instead of provider-data
- Deprecate mcp_authorization and mcp_headers from provider-data (MCPProviderDataValidator now empty)
- Update all Tool Runtime tests to pass authorization as request body parameter
- Responses API already uses request body authorization (no changes needed)

This provides a single, consistent way to pass MCP authentication tokens across both APIs, addressing reviewer feedback about avoiding multiple configuration paths.
This commit is contained in:
Omar Abdelwahab 2025-11-12 14:41:00 -08:00
parent 893e186d5c
commit 84baa5c406
6 changed files with 87 additions and 134 deletions

View file

@ -196,22 +196,32 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
"""List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:param authorization: (Optional) OAuth access token for authenticating with the MCP server.
:returns: A ListToolDefsResponse.
"""
...
@webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self,
tool_name: str,
kwargs: dict[str, Any],
authorization: str | None = None,
) -> ToolInvocationResult:
"""Run a tool with the given arguments.
:param tool_name: The name of the tool to invoke.
:param kwargs: A dictionary of arguments to pass to the tool.
:param authorization: (Optional) OAuth access token for authenticating with the MCP server.
:returns: A ToolInvocationResult.
"""
...

View file

@ -6,41 +6,20 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel):
"""
Validator for MCP provider-specific data passed via request headers.
Example usage:
HTTP Request Headers:
X-LlamaStack-Provider-Data: {
"mcp_headers": {
"http://weather-mcp.com": {
"X-Trace-ID": "trace-123",
"X-Request-ID": "req-456"
}
},
"mcp_authorization": {
"http://weather-mcp.com": "weather_api_token_xyz"
}
}
Security Note:
- Authorization header MUST NOT be placed in mcp_headers
- Use the dedicated mcp_authorization field instead
- Each MCP endpoint can have its own separate token
Note: MCP authentication and headers are now configured via the request body
(OpenAIResponseInputToolMCP.authorization and .headers fields) rather than
via provider data to simplify the API and avoid multiple configuration paths.
This validator is kept for future provider-data extensions if needed.
"""
# mcp_endpoint => dict of headers to send (excluding Authorization)
mcp_headers: dict[str, dict[str, str]] | None = None
# mcp_endpoint => authorization token
# Example: {"http://server.com": "token123"}
# Security: exclude=True ensures this field is NEVER included in:
# - API responses
# - Logs
# - Serialization (model_dump, dict(), json())
mcp_authorization: dict[str, str] | None = Field(default=None, exclude=True)
pass
class MCPProviderConfig(BaseModel):

View file

@ -25,7 +25,9 @@ from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class ModelContextProtocolToolRuntimeImpl(
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@ -39,15 +41,23 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
return
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
# Authorization now comes from request body parameter (not provider-data)
headers = {}
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
@ -55,7 +65,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers, authorization = await self.get_headers_from_request(endpoint)
# Authorization now comes from request body parameter (not provider-data)
headers = {}
return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
@ -64,58 +75,22 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
authorization=authorization,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
async def get_headers_from_request(
self, mcp_endpoint_uri: str
) -> tuple[dict[str, str], str | None]:
"""
Extract headers and authorization from request provider data.
Placeholder method for extracting headers and authorization.
For security, Authorization should not be passed via mcp_headers.
Instead, use a dedicated authorization field in the provider data.
Note: MCP authentication and headers are now configured via the request body
(OpenAIResponseInputToolMCP.authorization and .headers fields) and are handled
by the responses API layer, not at the provider level.
This method is kept for interface compatibility but returns empty values
as the tool runtime provider no longer extracts per-request configuration.
Returns:
Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
Raises:
ValueError: If Authorization header is found in mcp_headers (security risk)
Tuple of (empty_headers_dict, None)
"""
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
# PRIMARY SECURITY: This line prevents inference token leakage
# provider_data only contains X-LlamaStack-Provider-Data (request body),
# never the HTTP Authorization header (which contains the inference token)
provider_data = self.get_request_provider_data()
if provider_data:
# Extract headers (excluding Authorization)
if provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
# Security check: reject Authorization header in mcp_headers
# This enforces using the dedicated mcp_authorization field for auth tokens
# Note: Inference tokens are already isolated by line 89 (provider_data only contains request body)
for key in values.keys():
if key.lower() == "authorization":
raise ValueError(
"Authorization header cannot be passed via 'mcp_headers'. "
"Please use 'mcp_authorization' in provider_data instead."
)
# Collect all headers (Authorization already rejected above)
headers.update(values)
# Extract authorization from dedicated field
if provider_data.mcp_authorization:
canonical_endpoint = canonicalize_uri(mcp_endpoint_uri)
for uri, token in provider_data.mcp_authorization.items():
if canonicalize_uri(uri) == canonical_endpoint:
authorization = token
break
return headers, authorization
# Headers and authorization are now handled at the responses API layer
# via OpenAIResponseInputToolMCP.headers and .authorization fields
return {}, None

View file

@ -193,15 +193,15 @@ class TestMCPToolsInChatCompletion:
mcp_endpoint=dict(uri=uri),
)
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
# Authorization now passed as request body parameter
# Removed auth_headers - using authorization parameter instead
# (no longer needed)
}
# Get the tools from MCP
tools_response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Convert to OpenAI format for inference

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
@ -42,21 +40,13 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
mcp_endpoint=dict(uri=uri),
)
provider_data = {
"mcp_authorization": {
uri: AUTH_TOKEN, # Token
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Authorization now passed as request body parameter (not provider-data)
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN, # Pass authorization as parameter
)
assert len(tools_list) == 2
assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
@ -64,7 +54,7 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"),
extra_headers=auth_headers,
authorization=AUTH_TOKEN, # Pass authorization as parameter
)
content = response.content
assert len(content) == 1
@ -105,7 +95,6 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
}
],
stream=True,
extra_headers=auth_headers,
)
)
events = [chunk.event for chunk in chunks]

View file

@ -123,15 +123,15 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
# Authorization now passed as request body parameter
# Removed auth_headers - using authorization parameter instead
# (no longer needed)
}
# List runtime tools
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
tools = response
@ -166,15 +166,15 @@ class TestMCPSchemaPreservation:
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
)
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
# Authorization now passed as request body parameter
# Removed auth_headers - using authorization parameter instead
# (no longer needed)
}
# List tools
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Find book_flight tool (which should have $ref/$defs)
@ -216,14 +216,14 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
# Authorization now passed as request body parameter
# Removed auth_headers - using authorization parameter instead
# (no longer needed)
}
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Find get_weather tool
@ -263,15 +263,15 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri),
)
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
# Authorization now passed as request body parameter
# Removed auth_headers - using authorization parameter instead
# (no longer needed)
}
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Invoke tool with complex nested data
@ -283,7 +283,7 @@ class TestMCPToolInvocation:
"shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}},
}
},
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Should succeed without schema validation errors
@ -309,22 +309,22 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri),
)
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
# Authorization now passed as request body parameter
# Removed auth_headers - using authorization parameter instead
# (no longer needed)
}
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Test with email format
result_email = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact",
kwargs={"contact_info": "user@example.com"},
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
assert result_email.error_message is None
@ -333,7 +333,7 @@ class TestMCPToolInvocation:
result_phone = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact",
kwargs={"contact_info": "+15551234567"},
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
assert result_phone.error_message is None
@ -365,14 +365,14 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri),
)
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
# Authorization now passed as request body parameter
# Removed auth_headers - using authorization parameter instead
# (no longer needed)
}
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
tool_defs = [
{
@ -389,7 +389,7 @@ class TestAgentWithMCPTools:
model=text_model_id,
instructions="You are a helpful assistant that can process orders and book flights.",
tools=tool_defs,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
session_id = agent.create_session("test-session-complex")
@ -411,7 +411,7 @@ class TestAgentWithMCPTools:
}
],
stream=True,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
)