mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: implement dedicated mcp_authorization field for remote provider
Completes the TODO for extracting authorization from a dedicated field.
What changed:
- Added mcp_authorization field to MCPProviderDataValidator
- Updated get_headers_from_request() to extract from mcp_authorization
- Authorization is now properly isolated per MCP endpoint
API usage example:
{
"provider_data": {
"mcp_headers": {
"http://mcp-server.com": {
"X-Trace-ID": "trace-123"
}
},
"mcp_authorization": {
"http://mcp-server.com": "mcp_token_xyz789"
}
}
}
Security guarantees:
- Authorization cannot be in mcp_headers (validation rejects it)
- Each MCP endpoint gets its own dedicated token
- No cross-service token leakage possible
This commit is contained in:
parent
a842c90059
commit
445135b8cc
2 changed files with 34 additions and 19 deletions
|
|
@ -12,6 +12,8 @@ from pydantic import BaseModel
|
||||||
class MCPProviderDataValidator(BaseModel):
|
class MCPProviderDataValidator(BaseModel):
|
||||||
# mcp_endpoint => dict of headers to send
|
# mcp_endpoint => dict of headers to send
|
||||||
mcp_headers: dict[str, dict[str, str]] | None = None
|
mcp_headers: dict[str, dict[str, str]] | None = None
|
||||||
|
# mcp_endpoint => authorization token (without "Bearer " prefix)
|
||||||
|
mcp_authorization: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class MCPProviderConfig(BaseModel):
|
class MCPProviderConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,9 @@ from .config import MCPProviderConfig
|
||||||
logger = get_logger(__name__, category="tools")
|
logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class ModelContextProtocolToolRuntimeImpl(
|
||||||
|
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||||
|
):
|
||||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|
@ -45,9 +47,13 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
||||||
if mcp_endpoint is None:
|
if mcp_endpoint is None:
|
||||||
raise ValueError("mcp_endpoint is required")
|
raise ValueError("mcp_endpoint is required")
|
||||||
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
|
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||||
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
|
return await list_mcp_tools(
|
||||||
|
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
|
||||||
|
)
|
||||||
|
|
||||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
async def invoke_tool(
|
||||||
|
self, tool_name: str, kwargs: dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
tool = await self.tool_store.get_tool(tool_name)
|
tool = await self.tool_store.get_tool(tool_name)
|
||||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||||
raise ValueError(f"Tool {tool_name} does not have metadata")
|
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||||
|
|
@ -89,7 +95,9 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
||||||
authorization = None
|
authorization = None
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data and provider_data.mcp_headers:
|
if provider_data:
|
||||||
|
# Extract headers (excluding Authorization)
|
||||||
|
if provider_data.mcp_headers:
|
||||||
for uri, values in provider_data.mcp_headers.items():
|
for uri, values in provider_data.mcp_headers.items():
|
||||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||||
continue
|
continue
|
||||||
|
|
@ -100,13 +108,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
||||||
if key.lower() == "authorization":
|
if key.lower() == "authorization":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Authorization header cannot be passed via 'mcp_headers'. "
|
"Authorization header cannot be passed via 'mcp_headers'. "
|
||||||
"Please use a dedicated authorization field in provider_data instead."
|
"Please use 'mcp_authorization' in provider_data instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect all headers (Authorization already rejected above)
|
# Collect all headers (Authorization already rejected above)
|
||||||
headers.update(values)
|
headers.update(values)
|
||||||
|
|
||||||
# TODO: Extract authorization from a dedicated field in provider_data
|
# Extract authorization from dedicated field
|
||||||
# For now, authorization remains None until the API is updated
|
if provider_data.mcp_authorization:
|
||||||
|
canonical_endpoint = canonicalize_uri(mcp_endpoint_uri)
|
||||||
|
for uri, token in provider_data.mcp_authorization.items():
|
||||||
|
if canonicalize_uri(uri) == canonical_endpoint:
|
||||||
|
authorization = token
|
||||||
|
break
|
||||||
|
|
||||||
return headers, authorization
|
return headers, authorization
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue