feat: unify MCP authentication across Responses and Tool Runtime APIs

- Add authorization parameter to Tool Runtime API signatures (list_runtime_tools, invoke_tool)
- Update MCP provider implementation to use authorization from request body instead of provider-data
- Deprecate mcp_authorization and mcp_headers from provider-data (MCPProviderDataValidator now empty)
- Update all Tool Runtime tests to pass authorization as request body parameter
- Responses API already uses request body authorization (no changes needed)

This provides a single, consistent way to pass MCP authentication tokens across both APIs, addressing reviewer feedback about avoiding multiple configuration paths.
This commit is contained in:
Omar Abdelwahab 2025-11-12 14:41:00 -08:00
parent 893e186d5c
commit 84baa5c406
6 changed files with 87 additions and 134 deletions

View file

@ -196,22 +196,32 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse: ) -> ListToolDefsResponse:
"""List all tools in the runtime. """List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for. :param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group. :param mcp_endpoint: The MCP endpoint to use for the tool group.
:param authorization: (Optional) OAuth access token for authenticating with the MCP server.
:returns: A ListToolDefsResponse. :returns: A ListToolDefsResponse.
""" """
... ...
@webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(
self,
tool_name: str,
kwargs: dict[str, Any],
authorization: str | None = None,
) -> ToolInvocationResult:
"""Run a tool with the given arguments. """Run a tool with the given arguments.
:param tool_name: The name of the tool to invoke. :param tool_name: The name of the tool to invoke.
:param kwargs: A dictionary of arguments to pass to the tool. :param kwargs: A dictionary of arguments to pass to the tool.
:param authorization: (Optional) OAuth access token for authenticating with the MCP server.
:returns: A ToolInvocationResult. :returns: A ToolInvocationResult.
""" """
... ...

View file

@ -6,41 +6,20 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel): class MCPProviderDataValidator(BaseModel):
""" """
Validator for MCP provider-specific data passed via request headers. Validator for MCP provider-specific data passed via request headers.
Example usage:
HTTP Request Headers: Note: MCP authentication and headers are now configured via the request body
X-LlamaStack-Provider-Data: { (OpenAIResponseInputToolMCP.authorization and .headers fields) rather than
"mcp_headers": { via provider data to simplify the API and avoid multiple configuration paths.
"http://weather-mcp.com": {
"X-Trace-ID": "trace-123", This validator is kept for future provider-data extensions if needed.
"X-Request-ID": "req-456"
}
},
"mcp_authorization": {
"http://weather-mcp.com": "weather_api_token_xyz"
}
}
Security Note:
- Authorization header MUST NOT be placed in mcp_headers
- Use the dedicated mcp_authorization field instead
- Each MCP endpoint can have its own separate token
""" """
pass
# mcp_endpoint => dict of headers to send (excluding Authorization)
mcp_headers: dict[str, dict[str, str]] | None = None
# mcp_endpoint => authorization token
# Example: {"http://server.com": "token123"}
# Security: exclude=True ensures this field is NEVER included in:
# - API responses
# - Logs
# - Serialization (model_dump, dict(), json())
mcp_authorization: dict[str, str] | None = Field(default=None, exclude=True)
class MCPProviderConfig(BaseModel): class MCPProviderConfig(BaseModel):

View file

@ -25,7 +25,9 @@ from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools") logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): class ModelContextProtocolToolRuntimeImpl(
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config self.config = config
@ -39,15 +41,23 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
return return
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None self,
tool_group_id: str | None = None,
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse: ) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right? # this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None: if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required") raise ValueError("mcp_endpoint is required")
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri) # Authorization now comes from request body parameter (not provider-data)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization) headers = {}
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name) tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None: if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata") raise ValueError(f"Tool {tool_name} does not have metadata")
@ -55,7 +65,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"): if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers, authorization = await self.get_headers_from_request(endpoint) # Authorization now comes from request body parameter (not provider-data)
headers = {}
return await invoke_mcp_tool( return await invoke_mcp_tool(
endpoint=endpoint, endpoint=endpoint,
tool_name=tool_name, tool_name=tool_name,
@ -64,58 +75,22 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
authorization=authorization, authorization=authorization,
) )
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]: async def get_headers_from_request(
self, mcp_endpoint_uri: str
) -> tuple[dict[str, str], str | None]:
""" """
Extract headers and authorization from request provider data. Placeholder method for extracting headers and authorization.
For security, Authorization should not be passed via mcp_headers. Note: MCP authentication and headers are now configured via the request body
Instead, use a dedicated authorization field in the provider data. (OpenAIResponseInputToolMCP.authorization and .headers fields) and are handled
by the responses API layer, not at the provider level.
This method is kept for interface compatibility but returns empty values
as the tool runtime provider no longer extracts per-request configuration.
Returns: Returns:
Tuple of (headers_dict, authorization_token) Tuple of (empty_headers_dict, None)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
Raises:
ValueError: If Authorization header is found in mcp_headers (security risk)
""" """
# Headers and authorization are now handled at the responses API layer
def canonicalize_uri(uri: str) -> str: # via OpenAIResponseInputToolMCP.headers and .authorization fields
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" return {}, None
headers = {}
authorization = None
# PRIMARY SECURITY: This line prevents inference token leakage
# provider_data only contains X-LlamaStack-Provider-Data (request body),
# never the HTTP Authorization header (which contains the inference token)
provider_data = self.get_request_provider_data()
if provider_data:
# Extract headers (excluding Authorization)
if provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
# Security check: reject Authorization header in mcp_headers
# This enforces using the dedicated mcp_authorization field for auth tokens
# Note: Inference tokens are already isolated by line 89 (provider_data only contains request body)
for key in values.keys():
if key.lower() == "authorization":
raise ValueError(
"Authorization header cannot be passed via 'mcp_headers'. "
"Please use 'mcp_authorization' in provider_data instead."
)
# Collect all headers (Authorization already rejected above)
headers.update(values)
# Extract authorization from dedicated field
if provider_data.mcp_authorization:
canonical_endpoint = canonicalize_uri(mcp_endpoint_uri)
for uri, token in provider_data.mcp_authorization.items():
if canonicalize_uri(uri) == canonical_endpoint:
authorization = token
break
return headers, authorization

View file

@ -193,15 +193,15 @@ class TestMCPToolsInChatCompletion:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix # Authorization now passed as request body parameter
auth_headers = { # Removed auth_headers - using authorization parameter instead
"X-LlamaStack-Provider-Data": json.dumps(provider_data), # (no longer needed)
} }
# Get the tools from MCP # Get the tools from MCP
tools_response = llama_stack_client.tool_runtime.list_tools( tools_response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Convert to OpenAI format for inference # Convert to OpenAI format for inference

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import pytest import pytest
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
@ -42,21 +40,13 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = { # Authorization now passed as request body parameter (not provider-data)
"mcp_authorization": {
uri: AUTH_TOKEN, # Token
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
with pytest.raises(Exception, match="Unauthorized"): with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
tools_list = llama_stack_client.tools.list( tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id, toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN, # Pass authorization as parameter
) )
assert len(tools_list) == 2 assert len(tools_list) == 2
assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"} assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
@ -64,7 +54,7 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
response = llama_stack_client.tool_runtime.invoke_tool( response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="greet_everyone", tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"), kwargs=dict(url="https://www.google.com"),
extra_headers=auth_headers, authorization=AUTH_TOKEN, # Pass authorization as parameter
) )
content = response.content content = response.content
assert len(content) == 1 assert len(content) == 1
@ -105,7 +95,6 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
} }
], ],
stream=True, stream=True,
extra_headers=auth_headers,
) )
) )
events = [chunk.event for chunk in chunks] events = [chunk.event for chunk in chunks]

View file

@ -123,15 +123,15 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix # Authorization now passed as request body parameter
auth_headers = { # Removed auth_headers - using authorization parameter instead
"X-LlamaStack-Provider-Data": json.dumps(provider_data), # (no longer needed)
} }
# List runtime tools # List runtime tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
tools = response tools = response
@ -166,15 +166,15 @@ class TestMCPSchemaPreservation:
provider_id="model-context-protocol", provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix # Authorization now passed as request body parameter
auth_headers = { # Removed auth_headers - using authorization parameter instead
"X-LlamaStack-Provider-Data": json.dumps(provider_data), # (no longer needed)
} }
# List tools # List tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Find book_flight tool (which should have $ref/$defs) # Find book_flight tool (which should have $ref/$defs)
@ -216,14 +216,14 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix # Authorization now passed as request body parameter
auth_headers = { # Removed auth_headers - using authorization parameter instead
"X-LlamaStack-Provider-Data": json.dumps(provider_data), # (no longer needed)
} }
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Find get_weather tool # Find get_weather tool
@ -263,15 +263,15 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix # Authorization now passed as request body parameter
auth_headers = { # Removed auth_headers - using authorization parameter instead
"X-LlamaStack-Provider-Data": json.dumps(provider_data), # (no longer needed)
} }
# List tools to populate the tool index # List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Invoke tool with complex nested data # Invoke tool with complex nested data
@ -283,7 +283,7 @@ class TestMCPToolInvocation:
"shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}}, "shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}},
} }
}, },
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Should succeed without schema validation errors # Should succeed without schema validation errors
@ -309,22 +309,22 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix # Authorization now passed as request body parameter
auth_headers = { # Removed auth_headers - using authorization parameter instead
"X-LlamaStack-Provider-Data": json.dumps(provider_data), # (no longer needed)
} }
# List tools to populate the tool index # List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Test with email format # Test with email format
result_email = llama_stack_client.tool_runtime.invoke_tool( result_email = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact", tool_name="flexible_contact",
kwargs={"contact_info": "user@example.com"}, kwargs={"contact_info": "user@example.com"},
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
assert result_email.error_message is None assert result_email.error_message is None
@ -333,7 +333,7 @@ class TestMCPToolInvocation:
result_phone = llama_stack_client.tool_runtime.invoke_tool( result_phone = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact", tool_name="flexible_contact",
kwargs={"contact_info": "+15551234567"}, kwargs={"contact_info": "+15551234567"},
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
assert result_phone.error_message is None assert result_phone.error_message is None
@ -365,14 +365,14 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix # Authorization now passed as request body parameter
auth_headers = { # Removed auth_headers - using authorization parameter instead
"X-LlamaStack-Provider-Data": json.dumps(provider_data), # (no longer needed)
} }
tools_list = llama_stack_client.tools.list( tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id, toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
tool_defs = [ tool_defs = [
{ {
@ -389,7 +389,7 @@ class TestAgentWithMCPTools:
model=text_model_id, model=text_model_id,
instructions="You are a helpful assistant that can process orders and book flights.", instructions="You are a helpful assistant that can process orders and book flights.",
tools=tool_defs, tools=tool_defs,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
session_id = agent.create_session("test-session-complex") session_id = agent.create_session("test-session-complex")
@ -411,7 +411,7 @@ class TestAgentWithMCPTools:
} }
], ],
stream=True, stream=True,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
) )