Removing authorization from provider_data, enforce authorization parameter

This commit is contained in:
Omar Abdelwahab 2025-11-14 14:11:06 -08:00
parent eb545034ab
commit 575d503aaf
5 changed files with 60 additions and 159 deletions

View file

@ -104,23 +104,19 @@ client.toolgroups.register(
)
```
Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide authorization headers to send to the MCP server using the "Provider Data" abstraction provided by Llama Stack. When making an agent call,
Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide the authorization token when creating the Agent:
```python
agent = Agent(
...,
tools=["mcp::deepwiki"],
extra_headers={
"X-LlamaStack-Provider-Data": json.dumps(
tools=[
{
"mcp_headers": {
"http://mcp.deepwiki.com/sse": {
"Authorization": "Bearer <your_access_token>",
},
},
"type": "mcp",
"server_url": "https://mcp.deepwiki.com/sse",
"server_label": "mcp::deepwiki",
"authorization": "<your_access_token>", # OAuth token (without "Bearer " prefix)
}
),
},
],
)
agent.create_turn(...)
```

View file

@ -48,15 +48,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
# Phase 2: Only use the dedicated authorization parameter
# Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization
)
async def invoke_tool(
@ -69,39 +66,39 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
# Phase 2: Only use the dedicated authorization parameter
# Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(endpoint)
return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
kwargs=kwargs,
headers=provider_headers,
authorization=final_authorization,
authorization=authorization,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
"""
Extract headers and authorization from request provider data (Phase 1 backward compatibility).
Extract headers from request provider data, excluding authorization.
Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility.
Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead.
Phase 2: Authorization must be provided via the dedicated authorization parameter.
If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.
Args:
mcp_endpoint_uri: The MCP endpoint URI to match against provider data
Returns:
Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
dict[str, str]: Headers dictionary (without Authorization)
Raises:
ValueError: If Authorization header is found in mcp_headers
"""
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:
@ -109,17 +106,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
# Phase 1: Extract Authorization from mcp_headers for backward compatibility
# (Phase 2 will reject this and require the dedicated authorization parameter)
# Phase 2: Reject Authorization in mcp_headers - must use authorization parameter
for key in values.keys():
if key.lower() == "authorization":
# Extract authorization token and strip "Bearer " prefix if present
auth_value = values[key]
if auth_value.startswith("Bearer "):
authorization = auth_value[7:] # Remove "Bearer " prefix
else:
authorization = auth_value
else:
raise ValueError(
"Authorization cannot be provided via mcp_headers in provider_data. "
"Please use the dedicated 'authorization' parameter instead. "
"Example: tool_runtime.invoke_tool(..., authorization='your-token')"
)
headers[key] = values[key]
return headers, authorization
return headers

View file

@ -9,8 +9,6 @@ Integration tests for inference/chat completion with JSON Schema-based tools.
Tests that tools pass through correctly to various LLM providers.
"""
import json
import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient
@ -193,22 +191,11 @@ class TestMCPToolsInChatCompletion:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Phase 2: Use the dedicated authorization parameter
# Get the tools from MCP
tools_response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Convert to OpenAI format for inference

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
@ -37,32 +35,26 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Phase 2: Use the dedicated authorization parameter (no more provider_data headers)
# This tests direct tool_runtime.invoke_tool API calls
# Without authorization, should get Unauthorized error
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
# With authorization parameter, should succeed
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers, # Use old header-based approach
authorization=AUTH_TOKEN, # Use dedicated authorization parameter
)
assert len(tools_list) == 2
assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
# Invoke tool with authorization parameter
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"),
extra_headers=auth_headers, # Use old header-based approach
authorization=AUTH_TOKEN, # Use dedicated authorization parameter
)
content = response.content
assert len(content) == 1

View file

@ -8,8 +8,6 @@
Tests $ref, $defs, and other JSON Schema features through MCP integration.
"""
import json
import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient
@ -122,22 +120,11 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Phase 2: Use the dedicated authorization parameter
# List runtime tools
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
tools = response
@ -173,22 +160,11 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Phase 2: Use the dedicated authorization parameter
# List tools
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Find book_flight tool (which should have $ref/$defs)
@ -230,21 +206,10 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Phase 2: Use the dedicated authorization parameter
response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Find get_weather tool
@ -284,22 +249,11 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Phase 2: Use the dedicated authorization parameter
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Invoke tool with complex nested data
@ -311,7 +265,7 @@ class TestMCPToolInvocation:
"shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}},
}
},
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Should succeed without schema validation errors
@ -337,29 +291,18 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Phase 2: Use the dedicated authorization parameter
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
# Test with email format
result_email = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact",
kwargs={"contact_info": "user@example.com"},
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
assert result_email.error_message is None
@ -368,7 +311,7 @@ class TestMCPToolInvocation:
result_phone = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact",
kwargs={"contact_info": "+15551234567"},
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
assert result_phone.error_message is None
@ -400,21 +343,10 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri),
)
# Use old header-based approach for Phase 1 (backward compatibility)
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Use the dedicated authorization parameter
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,
authorization=AUTH_TOKEN,
)
tool_defs = [
{