From 8783255bc34a75c340c1854f87f6faa30aa8a36e Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Thu, 13 Nov 2025 10:26:39 -0800 Subject: [PATCH] feat(tool-runtime): Add authorization parameter with backward compatibility Implement Phase 1 of MCP auth migration: - Add authorization parameter to list_runtime_tools() and invoke_tool() - Maintain backward compatibility with X-LlamaStack-Provider-Data header - Tests use old header-based auth to avoid client SDK dependency - New parameter takes precedence when both methods provided Phase 2 will migrate tests to new parameter after Stainless SDK release. Related: PR #4052 --- .../model_context_protocol/config.py | 9 +- .../model_context_protocol.py | 75 ++++++++++---- tests/integration/tool_runtime/test_mcp.py | 19 +++- .../tool_runtime/test_mcp_json_schema.py | 98 ++++++++++++++++--- 4 files changed, 161 insertions(+), 40 deletions(-) diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index f2ae0c00b..9acabfc34 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -13,14 +13,11 @@ class MCPProviderDataValidator(BaseModel): """ Validator for MCP provider-specific data passed via request headers. - Note: MCP authentication and headers are now configured via the request body - (OpenAIResponseInputToolMCP.authorization and .headers fields) rather than - via provider data to simplify the API and avoid multiple configuration paths. - - This validator is kept for future provider-data extensions if needed. + Phase 1: Support old header-based authentication for backward compatibility. + In Phase 2, this will be deprecated in favor of the authorization parameter. 
""" - pass + mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict class MCPProviderConfig(BaseModel): diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 337a30415..3ef3e055e 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -48,9 +48,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - # Use authorization parameter for MCP servers that require auth - headers = {} - return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization) + # Phase 1: Support both old header-based auth AND new authorization parameter + # Get headers and auth from provider data (old approach) + provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri) + + # New authorization parameter takes precedence over provider data + final_authorization = authorization or provider_auth + + return await list_mcp_tools( + endpoint=mcp_endpoint.uri, + headers=provider_headers, + authorization=final_authorization + ) async def invoke_tool( self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None @@ -62,30 +71,60 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - # Authorization now comes from request body parameter (not provider-data) - headers = {} + # Phase 1: Support both old header-based auth AND new authorization parameter + # Get headers and auth from provider data (old approach) + provider_headers, provider_auth 
= await self.get_headers_from_request(endpoint) + + # New authorization parameter takes precedence over provider data + final_authorization = authorization or provider_auth + return await invoke_mcp_tool( endpoint=endpoint, tool_name=tool_name, kwargs=kwargs, - headers=headers, - authorization=authorization, + headers=provider_headers, + authorization=final_authorization, ) async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]: """ - Placeholder method for extracting headers and authorization. + Extract headers and authorization from request provider data (Phase 1 backward compatibility). - Note: MCP authentication and headers are now configured via the request body - (OpenAIResponseInputToolMCP.authorization and .headers fields) and are handled - by the responses API layer, not at the provider level. - - This method is kept for interface compatibility but returns empty values - as the tool runtime provider no longer extracts per-request configuration. + For security, Authorization should not be passed via mcp_headers. + Instead, use a dedicated authorization field in the provider data. 
Returns: - Tuple of (empty_headers_dict, None) + Tuple of (headers_dict, authorization_token) + - headers_dict: All headers except Authorization + - authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None + + Note: + An Authorization entry found in mcp_headers is not rejected; it is split out and returned as the token """ - # Headers and authorization are now handled at the responses API layer - # via OpenAIResponseInputToolMCP.headers and .authorization fields - return {}, None + + def canonicalize_uri(uri: str) -> str: + return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" + + headers = {} + authorization = None + + provider_data = self.get_request_provider_data() + if provider_data and hasattr(provider_data, 'mcp_headers') and provider_data.mcp_headers: + for uri, values in provider_data.mcp_headers.items(): + if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): + continue + + # Split the Authorization header (matched case-insensitively) out of mcp_headers + # so the token is returned separately from the remaining headers + for key in values.keys(): + if key.lower() == "authorization": + # Extract authorization token and strip "Bearer " prefix if present + auth_value = values[key] + if auth_value.startswith("Bearer "): + authorization = auth_value[7:] # Remove "Bearer " prefix + else: + authorization = auth_value + else: + headers[key] = values[key] + + return headers, authorization diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index e576e2770..1b7f509d2 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ b/tests/integration/tool_runtime/test_mcp.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import json + import pytest from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta @@ -35,13 +37,24 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server): mcp_endpoint=dict(uri=uri), ) - # Authorization now passed as request body parameter (not provider-data) + # Use old header-based approach for Phase 1 (backward compatibility) + provider_data = { + "mcp_headers": { + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, + }, + } + auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + with pytest.raises(Exception, match="Unauthorized"): llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) tools_list = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, - authorization=AUTH_TOKEN, # Pass authorization as parameter + extra_headers=auth_headers, # Use old header-based approach ) assert len(tools_list) == 2 assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"} @@ -49,7 +62,7 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server): response = llama_stack_client.tool_runtime.invoke_tool( tool_name="greet_everyone", kwargs=dict(url="https://www.google.com"), - authorization=AUTH_TOKEN, # Pass authorization as parameter + extra_headers=auth_headers, # Use old header-based approach ) content = response.content assert len(content) == 1 diff --git a/tests/integration/tool_runtime/test_mcp_json_schema.py b/tests/integration/tool_runtime/test_mcp_json_schema.py index 567380244..719588c7f 100644 --- a/tests/integration/tool_runtime/test_mcp_json_schema.py +++ b/tests/integration/tool_runtime/test_mcp_json_schema.py @@ -4,11 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -""" -Integration tests for MCP tools with complex JSON Schema support. 
+"""Integration tests for MCP tools with complex JSON Schema support. Tests $ref, $defs, and other JSON Schema features through MCP integration. """ +import json + import pytest from llama_stack.core.library_client import LlamaStackAsLibraryClient @@ -121,10 +122,22 @@ class TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) + # Use old header-based approach for Phase 1 (backward compatibility) + provider_data = { + "mcp_headers": { + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, + }, + } + auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + # List runtime tools response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) tools = response @@ -160,10 +173,22 @@ class TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) + # Use old header-based approach for Phase 1 (backward compatibility) + provider_data = { + "mcp_headers": { + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, + }, + } + auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + # List tools response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) # Find book_flight tool (which should have $ref/$defs) @@ -205,9 +230,21 @@ class TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) + # Use old header-based approach for Phase 1 (backward compatibility) + provider_data = { + "mcp_headers": { + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, + }, + } + auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) # Find get_weather tool @@ -247,10 +284,22 @@ class TestMCPToolInvocation: mcp_endpoint=dict(uri=uri), ) + # Use old header-based approach for Phase 1 (backward 
compatibility) + provider_data = { + "mcp_headers": { + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, + }, + } + auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + # List tools to populate the tool index llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) # Invoke tool with complex nested data @@ -262,7 +311,7 @@ class TestMCPToolInvocation: "shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}}, } }, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) # Should succeed without schema validation errors @@ -288,17 +337,29 @@ class TestMCPToolInvocation: mcp_endpoint=dict(uri=uri), ) + # Use old header-based approach for Phase 1 (backward compatibility) + provider_data = { + "mcp_headers": { + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, + }, + } + auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + # List tools to populate the tool index llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) # Test with email format result_email = llama_stack_client.tool_runtime.invoke_tool( tool_name="flexible_contact", kwargs={"contact_info": "user@example.com"}, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) assert result_email.error_message is None @@ -307,7 +368,7 @@ class TestMCPToolInvocation: result_phone = llama_stack_client.tool_runtime.invoke_tool( tool_name="flexible_contact", kwargs={"contact_info": "+15551234567"}, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) assert result_phone.error_message is None @@ -339,9 +400,21 @@ class TestAgentWithMCPTools: mcp_endpoint=dict(uri=uri), ) + # Use old header-based approach for Phase 1 (backward compatibility) + provider_data = { + "mcp_headers": { + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, + }, + } 
+ auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + tools_list = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, - authorization=AUTH_TOKEN, + extra_headers=auth_headers, ) tool_defs = [ { @@ -350,6 +423,7 @@ class TestAgentWithMCPTools: "server_label": test_toolgroup_id, "require_approval": "never", "allowed_tools": [tool.name for tool in tools_list], + "authorization": AUTH_TOKEN, } ] @@ -358,7 +432,6 @@ class TestAgentWithMCPTools: model=text_model_id, instructions="You are a helpful assistant that can process orders and book flights.", tools=tool_defs, - authorization=AUTH_TOKEN, ) session_id = agent.create_session("test-session-complex") @@ -380,7 +453,6 @@ class TestAgentWithMCPTools: } ], stream=True, - authorization=AUTH_TOKEN, ) )