feat(tool-runtime): Add authorization parameter with backward compatibility

Implement Phase 1 of MCP auth migration:
- Add authorization parameter to list_runtime_tools() and invoke_tool()
- Maintain backward compatibility with X-LlamaStack-Provider-Data header
- Tests use old header-based auth to avoid client SDK dependency
- New parameter takes precedence when both methods provided

Phase 2 will migrate tests to new parameter after Stainless SDK release.

Related: PR #4052
This commit is contained in:
Omar Abdelwahab 2025-11-13 10:26:39 -08:00
parent fa2b361f46
commit 8783255bc3
4 changed files with 161 additions and 40 deletions

View file

@ -13,14 +13,11 @@ class MCPProviderDataValidator(BaseModel):
"""
Validator for MCP provider-specific data passed via request headers.
Note: MCP authentication and headers are now configured via the request body
(OpenAIResponseInputToolMCP.authorization and .headers fields) rather than
via provider data to simplify the API and avoid multiple configuration paths.
This validator is kept for future provider-data extensions if needed.
Phase 1: Support old header-based authentication for backward compatibility.
In Phase 2, this will be deprecated in favor of the authorization parameter.
"""
pass
mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
class MCPProviderConfig(BaseModel):

View file

@ -48,9 +48,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
# Use authorization parameter for MCP servers that require auth
headers = {}
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await list_mcp_tools(
endpoint=mcp_endpoint.uri,
headers=provider_headers,
authorization=final_authorization
)
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
@ -62,30 +71,60 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
# Authorization now comes from request body parameter (not provider-data)
headers = {}
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
kwargs=kwargs,
headers=headers,
authorization=authorization,
headers=provider_headers,
authorization=final_authorization,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
"""
Placeholder method for extracting headers and authorization.
Extract headers and authorization from request provider data (Phase 1 backward compatibility).
Note: MCP authentication and headers are now configured via the request body
(OpenAIResponseInputToolMCP.authorization and .headers fields) and are handled
by the responses API layer, not at the provider level.
This method is kept for interface compatibility but returns empty values
as the tool runtime provider no longer extracts per-request configuration.
For security, Authorization should not be passed via mcp_headers.
Instead, use a dedicated authorization field in the provider data.
Returns:
Tuple of (empty_headers_dict, None)
Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
Raises:
ValueError: If Authorization header is found in mcp_headers (security risk)
"""
# Headers and authorization are now handled at the responses API layer
# via OpenAIResponseInputToolMCP.headers and .authorization fields
return {}, None
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and hasattr(provider_data, 'mcp_headers') and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
# Security check: reject Authorization header in mcp_headers
# This prevents accidentally passing inference tokens to MCP servers
for key in values.keys():
if key.lower() == "authorization":
# Extract authorization token and strip "Bearer " prefix if present
auth_value = values[key]
if auth_value.startswith("Bearer "):
authorization = auth_value[7:] # Remove "Bearer " prefix
else:
authorization = auth_value
else:
headers[key] = values[key]
return headers, authorization

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
@ -35,13 +37,24 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
mcp_endpoint=dict(uri=uri),
)
# Authorization now passed as request body parameter (not provider-data)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
authorization=AUTH_TOKEN, # Pass authorization as parameter
extra_headers=auth_headers, # Use old header-based approach
)
assert len(tools_list) == 2
assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
@ -49,7 +62,7 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"),
authorization=AUTH_TOKEN, # Pass authorization as parameter
extra_headers=auth_headers, # Use old header-based approach
)
content = response.content
assert len(content) == 1

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Integration tests for MCP tools with complex JSON Schema support.
"""Integration tests for MCP tools with complex JSON Schema support.
Tests $ref, $defs, and other JSON Schema features through MCP integration.
"""
import json
import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient
@ -121,10 +122,22 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List runtime tools
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
tools = response
@ -160,10 +173,22 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
# Find book_flight tool (which should have $ref/$defs)
@ -205,9 +230,21 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
# Find get_weather tool
@ -247,10 +284,22 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
# Invoke tool with complex nested data
@ -262,7 +311,7 @@ class TestMCPToolInvocation:
"shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}},
}
},
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
# Should succeed without schema validation errors
@ -288,17 +337,29 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
# Test with email format
result_email = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact",
kwargs={"contact_info": "user@example.com"},
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
assert result_email.error_message is None
@ -307,7 +368,7 @@ class TestMCPToolInvocation:
result_phone = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact",
kwargs={"contact_info": "+15551234567"},
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
assert result_phone.error_message is None
@ -339,9 +400,21 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
authorization=AUTH_TOKEN,
extra_headers=auth_headers,
)
tool_defs = [
{
@ -350,6 +423,7 @@ class TestAgentWithMCPTools:
"server_label": test_toolgroup_id,
"require_approval": "never",
"allowed_tools": [tool.name for tool in tools_list],
"authorization": AUTH_TOKEN,
}
]
@ -358,7 +432,6 @@ class TestAgentWithMCPTools:
model=text_model_id,
instructions="You are a helpful assistant that can process orders and book flights.",
tools=tool_defs,
authorization=AUTH_TOKEN,
)
session_id = agent.create_session("test-session-complex")
@ -380,7 +453,6 @@ class TestAgentWithMCPTools:
}
],
stream=True,
authorization=AUTH_TOKEN,
)
)