mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: implement dedicated mcp_authorization field for remote provider
Completes the TODO for extracting authorization from a dedicated field.
What changed:
- Added mcp_authorization field to MCPProviderDataValidator
- Updated get_headers_from_request() to extract from mcp_authorization
- Authorization is now properly isolated per MCP endpoint
API usage example:
{
"provider_data": {
"mcp_headers": {
"http://mcp-server.com": {
"X-Trace-ID": "trace-123"
}
},
"mcp_authorization": {
"http://mcp-server.com": "mcp_token_xyz789"
}
}
}
Security guarantees:
- Authorization cannot be in mcp_headers (validation rejects it)
- Each MCP endpoint gets its own dedicated token
- No cross-service token leakage possible
This commit is contained in:
parent
a842c90059
commit
445135b8cc
2 changed files with 34 additions and 19 deletions
|
|
@ -12,6 +12,8 @@ from pydantic import BaseModel
|
|||
class MCPProviderDataValidator(BaseModel):
|
||||
# mcp_endpoint => dict of headers to send
|
||||
mcp_headers: dict[str, dict[str, str]] | None = None
|
||||
# mcp_endpoint => authorization token (without "Bearer " prefix)
|
||||
mcp_authorization: dict[str, str] | None = None
|
||||
|
||||
|
||||
class MCPProviderConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -25,7 +25,9 @@ from .config import MCPProviderConfig
|
|||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class ModelContextProtocolToolRuntimeImpl(
|
||||
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||
self.config = config
|
||||
|
||||
|
|
@ -45,9 +47,13 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
|||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
|
||||
return await list_mcp_tools(
|
||||
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||
|
|
@ -89,24 +95,31 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
|||
authorization = None
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and provider_data.mcp_headers:
|
||||
for uri, values in provider_data.mcp_headers.items():
|
||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||
continue
|
||||
if provider_data:
|
||||
# Extract headers (excluding Authorization)
|
||||
if provider_data.mcp_headers:
|
||||
for uri, values in provider_data.mcp_headers.items():
|
||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||
continue
|
||||
|
||||
# Security check: reject Authorization header in mcp_headers
|
||||
# This prevents accidentally passing inference tokens to MCP servers
|
||||
for key in values.keys():
|
||||
if key.lower() == "authorization":
|
||||
raise ValueError(
|
||||
"Authorization header cannot be passed via 'mcp_headers'. "
|
||||
"Please use a dedicated authorization field in provider_data instead."
|
||||
)
|
||||
# Security check: reject Authorization header in mcp_headers
|
||||
# This prevents accidentally passing inference tokens to MCP servers
|
||||
for key in values.keys():
|
||||
if key.lower() == "authorization":
|
||||
raise ValueError(
|
||||
"Authorization header cannot be passed via 'mcp_headers'. "
|
||||
"Please use 'mcp_authorization' in provider_data instead."
|
||||
)
|
||||
|
||||
# Collect all headers (Authorization already rejected above)
|
||||
headers.update(values)
|
||||
# Collect all headers (Authorization already rejected above)
|
||||
headers.update(values)
|
||||
|
||||
# TODO: Extract authorization from a dedicated field in provider_data
|
||||
# For now, authorization remains None until the API is updated
|
||||
# Extract authorization from dedicated field
|
||||
if provider_data.mcp_authorization:
|
||||
canonical_endpoint = canonicalize_uri(mcp_endpoint_uri)
|
||||
for uri, token in provider_data.mcp_authorization.items():
|
||||
if canonicalize_uri(uri) == canonical_endpoint:
|
||||
authorization = token
|
||||
break
|
||||
|
||||
return headers, authorization
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue