feat(tool-runtime): Add authorization parameter with backward compatibility

Implement Phase 1 of MCP auth migration:
- Add authorization parameter to list_runtime_tools() and invoke_tool()
- Maintain backward compatibility with X-LlamaStack-Provider-Data header
- Tests use old header-based auth to avoid client SDK dependency
- New parameter takes precedence when both methods provided

Phase 2 will migrate tests to new parameter after Stainless SDK release.

Related: PR #4052
This commit is contained in:
Omar Abdelwahab 2025-11-13 10:26:39 -08:00
parent fa2b361f46
commit 8783255bc3
4 changed files with 161 additions and 40 deletions

View file

@ -13,14 +13,11 @@ class MCPProviderDataValidator(BaseModel):
""" """
Validator for MCP provider-specific data passed via request headers. Validator for MCP provider-specific data passed via request headers.
Note: MCP authentication and headers are now configured via the request body Phase 1: Support old header-based authentication for backward compatibility.
(OpenAIResponseInputToolMCP.authorization and .headers fields) rather than In Phase 2, this will be deprecated in favor of the authorization parameter.
via provider data to simplify the API and avoid multiple configuration paths.
This validator is kept for future provider-data extensions if needed.
""" """
pass mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
class MCPProviderConfig(BaseModel): class MCPProviderConfig(BaseModel):

View file

@ -48,9 +48,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if mcp_endpoint is None: if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required") raise ValueError("mcp_endpoint is required")
# Use authorization parameter for MCP servers that require auth # Phase 1: Support both old header-based auth AND new authorization parameter
headers = {} # Get headers and auth from provider data (old approach)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization) provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await list_mcp_tools(
endpoint=mcp_endpoint.uri,
headers=provider_headers,
authorization=final_authorization
)
async def invoke_tool( async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
@ -62,30 +71,60 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"): if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
# Authorization now comes from request body parameter (not provider-data) # Phase 1: Support both old header-based auth AND new authorization parameter
headers = {} # Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await invoke_mcp_tool( return await invoke_mcp_tool(
endpoint=endpoint, endpoint=endpoint,
tool_name=tool_name, tool_name=tool_name,
kwargs=kwargs, kwargs=kwargs,
headers=headers, headers=provider_headers,
authorization=authorization, authorization=final_authorization,
) )
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]: async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
""" """
Placeholder method for extracting headers and authorization. Extract headers and authorization from request provider data (Phase 1 backward compatibility).
Note: MCP authentication and headers are now configured via the request body For security, Authorization should not be passed via mcp_headers.
(OpenAIResponseInputToolMCP.authorization and .headers fields) and are handled Instead, use a dedicated authorization field in the provider data.
by the responses API layer, not at the provider level.
This method is kept for interface compatibility but returns empty values
as the tool runtime provider no longer extracts per-request configuration.
Returns: Returns:
Tuple of (empty_headers_dict, None) Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
Raises:
ValueError: If Authorization header is found in mcp_headers (security risk)
""" """
# Headers and authorization are now handled at the responses API layer
# via OpenAIResponseInputToolMCP.headers and .authorization fields def canonicalize_uri(uri: str) -> str:
return {}, None return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and hasattr(provider_data, 'mcp_headers') and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
# Security check: reject Authorization header in mcp_headers
# This prevents accidentally passing inference tokens to MCP servers
for key in values.keys():
if key.lower() == "authorization":
# Extract authorization token and strip "Bearer " prefix if present
auth_value = values[key]
if auth_value.startswith("Bearer "):
authorization = auth_value[7:] # Remove "Bearer " prefix
else:
authorization = auth_value
else:
headers[key] = values[key]
return headers, authorization

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import pytest import pytest
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
@ -35,13 +37,24 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Authorization now passed as request body parameter (not provider-data) # Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
with pytest.raises(Exception, match="Unauthorized"): with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
tools_list = llama_stack_client.tools.list( tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id, toolgroup_id=test_toolgroup_id,
authorization=AUTH_TOKEN, # Pass authorization as parameter extra_headers=auth_headers, # Use old header-based approach
) )
assert len(tools_list) == 2 assert len(tools_list) == 2
assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"} assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
@ -49,7 +62,7 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
response = llama_stack_client.tool_runtime.invoke_tool( response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="greet_everyone", tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"), kwargs=dict(url="https://www.google.com"),
authorization=AUTH_TOKEN, # Pass authorization as parameter extra_headers=auth_headers, # Use old header-based approach
) )
content = response.content content = response.content
assert len(content) == 1 assert len(content) == 1

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
""" """Integration tests for MCP tools with complex JSON Schema support.
Integration tests for MCP tools with complex JSON Schema support.
Tests $ref, $defs, and other JSON Schema features through MCP integration. Tests $ref, $defs, and other JSON Schema features through MCP integration.
""" """
import json
import pytest import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.core.library_client import LlamaStackAsLibraryClient
@ -121,10 +122,22 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List runtime tools # List runtime tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
tools = response tools = response
@ -160,10 +173,22 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools # List tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
# Find book_flight tool (which should have $ref/$defs) # Find book_flight tool (which should have $ref/$defs)
@ -205,9 +230,21 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
# Find get_weather tool # Find get_weather tool
@ -247,10 +284,22 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools to populate the tool index # List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
# Invoke tool with complex nested data # Invoke tool with complex nested data
@ -262,7 +311,7 @@ class TestMCPToolInvocation:
"shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}}, "shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}},
} }
}, },
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
# Should succeed without schema validation errors # Should succeed without schema validation errors
@ -288,17 +337,29 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools to populate the tool index # List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
# Test with email format # Test with email format
result_email = llama_stack_client.tool_runtime.invoke_tool( result_email = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact", tool_name="flexible_contact",
kwargs={"contact_info": "user@example.com"}, kwargs={"contact_info": "user@example.com"},
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
assert result_email.error_message is None assert result_email.error_message is None
@ -307,7 +368,7 @@ class TestMCPToolInvocation:
result_phone = llama_stack_client.tool_runtime.invoke_tool( result_phone = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact", tool_name="flexible_contact",
kwargs={"contact_info": "+15551234567"}, kwargs={"contact_info": "+15551234567"},
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
assert result_phone.error_message is None assert result_phone.error_message is None
@ -339,9 +400,21 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
tools_list = llama_stack_client.tools.list( tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id, toolgroup_id=test_toolgroup_id,
authorization=AUTH_TOKEN, extra_headers=auth_headers,
) )
tool_defs = [ tool_defs = [
{ {
@ -350,6 +423,7 @@ class TestAgentWithMCPTools:
"server_label": test_toolgroup_id, "server_label": test_toolgroup_id,
"require_approval": "never", "require_approval": "never",
"allowed_tools": [tool.name for tool in tools_list], "allowed_tools": [tool.name for tool in tools_list],
"authorization": AUTH_TOKEN,
} }
] ]
@ -358,7 +432,6 @@ class TestAgentWithMCPTools:
model=text_model_id, model=text_model_id,
instructions="You are a helpful assistant that can process orders and book flights.", instructions="You are a helpful assistant that can process orders and book flights.",
tools=tool_defs, tools=tool_defs,
authorization=AUTH_TOKEN,
) )
session_id = agent.create_session("test-session-complex") session_id = agent.create_session("test-session-complex")
@ -380,7 +453,6 @@ class TestAgentWithMCPTools:
} }
], ],
stream=True, stream=True,
authorization=AUTH_TOKEN,
) )
) )