From 84baa5c40606eb5daaf06680ed63626521d1468f Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Wed, 12 Nov 2025 14:41:00 -0800 Subject: [PATCH] feat: unify MCP authentication across Responses and Tool Runtime APIs - Add authorization parameter to Tool Runtime API signatures (list_runtime_tools, invoke_tool) - Update MCP provider implementation to use authorization from request body instead of provider-data - Remove mcp_authorization and mcp_headers from provider-data (MCPProviderDataValidator now empty) - Update all Tool Runtime tests to pass authorization as request body parameter - Responses API already uses request body authorization (no changes needed) This provides a single, consistent way to pass MCP authentication tokens across both APIs, addressing reviewer feedback about avoiding multiple configuration paths. --- src/llama_stack/apis/tools/tools.py | 14 ++- .../model_context_protocol/config.py | 37 ++------ .../model_context_protocol.py | 87 +++++++------------ .../inference/test_tools_with_schemas.py | 8 +- tests/integration/tool_runtime/test_mcp.py | 17 +--- .../tool_runtime/test_mcp_json_schema.py | 58 ++++++------- 6 files changed, 87 insertions(+), 134 deletions(-) diff --git a/src/llama_stack/apis/tools/tools.py b/src/llama_stack/apis/tools/tools.py index 4e7cf2544..06580dc74 100644 --- a/src/llama_stack/apis/tools/tools.py +++ b/src/llama_stack/apis/tools/tools.py @@ -196,22 +196,32 @@ class ToolRuntime(Protocol): # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1) async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, + tool_group_id: str | None = None, + mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: """List all tools in the runtime. :param tool_group_id: The ID of the tool group to list tools for. 
:param mcp_endpoint: The MCP endpoint to use for the tool group. + :param authorization: (Optional) OAuth access token for authenticating with the MCP server. :returns: A ListToolDefsResponse. """ ... @webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, + tool_name: str, + kwargs: dict[str, Any], + authorization: str | None = None, + ) -> ToolInvocationResult: """Run a tool with the given arguments. :param tool_name: The name of the tool to invoke. :param kwargs: A dictionary of arguments to pass to the tool. + :param authorization: (Optional) OAuth access token for authenticating with the MCP server. :returns: A ToolInvocationResult. """ ... diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index 265fd9918..290b13c26 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -6,41 +6,20 @@ from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel class MCPProviderDataValidator(BaseModel): """ Validator for MCP provider-specific data passed via request headers. 
- Example usage: - HTTP Request Headers: - X-LlamaStack-Provider-Data: { - "mcp_headers": { - "http://weather-mcp.com": { - "X-Trace-ID": "trace-123", - "X-Request-ID": "req-456" - } - }, - "mcp_authorization": { - "http://weather-mcp.com": "weather_api_token_xyz" - } - } - Security Note: - - Authorization header MUST NOT be placed in mcp_headers - - Use the dedicated mcp_authorization field instead - - Each MCP endpoint can have its own separate token + + Note: MCP authentication and headers are now configured via the request body + (OpenAIResponseInputToolMCP.authorization and .headers fields) rather than + via provider data to simplify the API and avoid multiple configuration paths. + + This validator is kept for future provider-data extensions if needed. """ - - # mcp_endpoint => dict of headers to send (excluding Authorization) - mcp_headers: dict[str, dict[str, str]] | None = None - - # mcp_endpoint => authorization token - # Example: {"http://server.com": "token123"} - # Security: exclude=True ensures this field is NEVER included in: - # - API responses - # - Logs - # - Serialization (model_dump, dict(), json()) - mcp_authorization: dict[str, str] | None = Field(default=None, exclude=True) + pass class MCPProviderConfig(BaseModel): diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 506aadf82..137effb33 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -25,7 +25,9 @@ from .config import MCPProviderConfig logger = get_logger(__name__, category="tools") -class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class ModelContextProtocolToolRuntimeImpl( + ToolGroupsProtocolPrivate, ToolRuntime, 
NeedsRequestProviderData +): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config @@ -39,15 +41,23 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime return async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, + tool_group_id: str | None = None, + mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: # this endpoint should be retrieved by getting the tool group right? if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri) - return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization) + # Authorization now comes from request body parameter (not provider-data) + headers = {} + return await list_mcp_tools( + endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization + ) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None + ) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) if tool.metadata is None or tool.metadata.get("endpoint") is None: raise ValueError(f"Tool {tool_name} does not have metadata") @@ -55,7 +65,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - headers, authorization = await self.get_headers_from_request(endpoint) + # Authorization now comes from request body parameter (not provider-data) + headers = {} return await invoke_mcp_tool( endpoint=endpoint, tool_name=tool_name, @@ -64,58 +75,22 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime authorization=authorization, ) - async 
def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]: + async def get_headers_from_request( + self, mcp_endpoint_uri: str + ) -> tuple[dict[str, str], str | None]: """ - Extract headers and authorization from request provider data. + Placeholder method for extracting headers and authorization. - For security, Authorization should not be passed via mcp_headers. - Instead, use a dedicated authorization field in the provider data. + Note: MCP authentication and headers are now configured via the request body + (OpenAIResponseInputToolMCP.authorization and .headers fields) and are handled + by the responses API layer, not at the provider level. + + This method is kept for interface compatibility but returns empty values + as the tool runtime provider no longer extracts per-request configuration. Returns: - Tuple of (headers_dict, authorization_token) - - headers_dict: All headers except Authorization - - authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None - - Raises: - ValueError: If Authorization header is found in mcp_headers (security risk) + Tuple of (empty_headers_dict, None) """ - - def canonicalize_uri(uri: str) -> str: - return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" - - headers = {} - authorization = None - - # PRIMARY SECURITY: This line prevents inference token leakage - # provider_data only contains X-LlamaStack-Provider-Data (request body), - # never the HTTP Authorization header (which contains the inference token) - provider_data = self.get_request_provider_data() - if provider_data: - # Extract headers (excluding Authorization) - if provider_data.mcp_headers: - for uri, values in provider_data.mcp_headers.items(): - if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): - continue - - # Security check: reject Authorization header in mcp_headers - # This enforces using the dedicated mcp_authorization field for auth tokens - # Note: Inference 
tokens are already isolated by line 89 (provider_data only contains request body) - for key in values.keys(): - if key.lower() == "authorization": - raise ValueError( - "Authorization header cannot be passed via 'mcp_headers'. " - "Please use 'mcp_authorization' in provider_data instead." - ) - - # Collect all headers (Authorization already rejected above) - headers.update(values) - - # Extract authorization from dedicated field - if provider_data.mcp_authorization: - canonical_endpoint = canonicalize_uri(mcp_endpoint_uri) - for uri, token in provider_data.mcp_authorization.items(): - if canonicalize_uri(uri) == canonical_endpoint: - authorization = token - break - - return headers, authorization + # Headers and authorization are now handled at the responses API layer + # via OpenAIResponseInputToolMCP.headers and .authorization fields + return {}, None diff --git a/tests/integration/inference/test_tools_with_schemas.py b/tests/integration/inference/test_tools_with_schemas.py index 9a3ac0bf0..116e8ff4c 100644 --- a/tests/integration/inference/test_tools_with_schemas.py +++ b/tests/integration/inference/test_tools_with_schemas.py @@ -193,15 +193,15 @@ class TestMCPToolsInChatCompletion: mcp_endpoint=dict(uri=uri), ) - provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), + # Authorization now passed as request body parameter + # Removed auth_headers - using authorization parameter instead + # (no longer needed) } # Get the tools from MCP tools_response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Convert to OpenAI format for inference diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index 8138f0d92..0d08e5a35 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ 
b/tests/integration/tool_runtime/test_mcp.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json - import pytest from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta @@ -42,21 +40,13 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server): mcp_endpoint=dict(uri=uri), ) - provider_data = { - "mcp_authorization": { - uri: AUTH_TOKEN, # Token - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - + # Authorization now passed as request body parameter (not provider-data) with pytest.raises(Exception, match="Unauthorized"): llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) tools_list = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, # Pass authorization as parameter ) assert len(tools_list) == 2 assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"} @@ -64,7 +54,7 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server): response = llama_stack_client.tool_runtime.invoke_tool( tool_name="greet_everyone", kwargs=dict(url="https://www.google.com"), - extra_headers=auth_headers, + authorization=AUTH_TOKEN, # Pass authorization as parameter ) content = response.content assert len(content) == 1 @@ -105,7 +95,6 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server): } ], stream=True, - extra_headers=auth_headers, ) ) events = [chunk.event for chunk in chunks] diff --git a/tests/integration/tool_runtime/test_mcp_json_schema.py b/tests/integration/tool_runtime/test_mcp_json_schema.py index 6302fa385..62e9844b4 100644 --- a/tests/integration/tool_runtime/test_mcp_json_schema.py +++ b/tests/integration/tool_runtime/test_mcp_json_schema.py @@ -123,15 +123,15 @@ class 
TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) - provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), + # Authorization now passed as request body parameter + # Removed auth_headers - using authorization parameter instead + # (no longer needed) } # List runtime tools response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) tools = response @@ -166,15 +166,15 @@ class TestMCPSchemaPreservation: provider_id="model-context-protocol", mcp_endpoint=dict(uri=uri), ) - provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), + # Authorization now passed as request body parameter + # Removed auth_headers - using authorization parameter instead + # (no longer needed) } # List tools response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Find book_flight tool (which should have $ref/$defs) @@ -216,14 +216,14 @@ class TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) - provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), + # Authorization now passed as request body parameter + # Removed auth_headers - using authorization parameter instead + # (no longer needed) } response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Find get_weather tool @@ -263,15 +263,15 @@ class TestMCPToolInvocation: mcp_endpoint=dict(uri=uri), ) - provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix - auth_headers = { - 
"X-LlamaStack-Provider-Data": json.dumps(provider_data), + # Authorization now passed as request body parameter + # Removed auth_headers - using authorization parameter instead + # (no longer needed) } # List tools to populate the tool index llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Invoke tool with complex nested data @@ -283,7 +283,7 @@ class TestMCPToolInvocation: "shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}}, } }, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Should succeed without schema validation errors @@ -309,22 +309,22 @@ class TestMCPToolInvocation: mcp_endpoint=dict(uri=uri), ) - provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), + # Authorization now passed as request body parameter + # Removed auth_headers - using authorization parameter instead + # (no longer needed) } # List tools to populate the tool index llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Test with email format result_email = llama_stack_client.tool_runtime.invoke_tool( tool_name="flexible_contact", kwargs={"contact_info": "user@example.com"}, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) assert result_email.error_message is None @@ -333,7 +333,7 @@ class TestMCPToolInvocation: result_phone = llama_stack_client.tool_runtime.invoke_tool( tool_name="flexible_contact", kwargs={"contact_info": "+15551234567"}, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) assert result_phone.error_message is None @@ -365,14 +365,14 @@ class TestAgentWithMCPTools: mcp_endpoint=dict(uri=uri), ) - provider_data = {"mcp_authorization": {uri: AUTH_TOKEN}} # Token without "Bearer " prefix - auth_headers = { - 
"X-LlamaStack-Provider-Data": json.dumps(provider_data), + # Authorization now passed as request body parameter + # Removed auth_headers - using authorization parameter instead + # (no longer needed) } tools_list = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) tool_defs = [ { @@ -389,7 +389,7 @@ class TestAgentWithMCPTools: model=text_model_id, instructions="You are a helpful assistant that can process orders and book flights.", tools=tool_defs, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) session_id = agent.create_session("test-session-complex") @@ -411,7 +411,7 @@ class TestAgentWithMCPTools: } ], stream=True, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) )