diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index e5b0bc3f7..61402707c 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -44,8 +44,10 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime # this endpoint should be retrieved by getting the tool group right? if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - headers = await self.get_headers_from_request(mcp_endpoint.uri) - return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers) + headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri) + return await list_mcp_tools( + endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization + ) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) @@ -55,24 +57,44 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - headers = await self.get_headers_from_request(endpoint) + headers, authorization = await self.get_headers_from_request(endpoint) return await invoke_mcp_tool( endpoint=endpoint, tool_name=tool_name, kwargs=kwargs, headers=headers, + authorization=authorization, ) - async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: + async def get_headers_from_request( + self, mcp_endpoint_uri: str + ) -> tuple[dict[str, str], str | None]: + """ + Extract headers and authorization from request provider data. + + Returns: + Tuple of (headers_dict, authorization_token) + - headers_dict: All headers except Authorization + - authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None + """ + def canonicalize_uri(uri: str) -> str: return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" headers = {} + authorization = None provider_data = self.get_request_provider_data() if provider_data and provider_data.mcp_headers: for uri, values in provider_data.mcp_headers.items(): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): continue - headers.update(values) - return headers + # Extract Authorization header separately for security + for key, value in values.items(): + if key.lower() == "authorization": + # Remove "Bearer " prefix if present + authorization = value.removeprefix("Bearer ").strip() + else: + headers[key] = value + + return headers, authorization