diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index d400159b2..b8c5e77fd 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -10,8 +10,8 @@ from pydantic import BaseModel class MCPProviderDataValidator(BaseModel): - # mcp_endpoint => list of headers to send - mcp_headers: dict[str, list[str]] | None = None + # mcp_endpoint => dict of headers to send + mcp_headers: dict[str, dict[str, str]] | None = None class MCPProviderConfig(BaseModel): diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 9603bf97e..a9b252dfe 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -18,7 +18,7 @@ from llama_stack.apis.tools import ( from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate -from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools +from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools from .config import MCPProviderConfig @@ -69,5 +69,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime for uri, values in provider_data.mcp_headers.items(): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): continue - headers.update(convert_header_list_to_dict(values)) + headers.update(values) return headers diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index 7bc372f07..f024693a0 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -51,16 +51,6 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): raise -def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]: - headers = {} - for header in header_list: - parts = header.split(":") - if len(parts) == 2: - k, v = parts - headers[k.strip()] = v.strip() - return headers - - async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: tools = [] async with sse_client_wrapper(endpoint, headers) as session: diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index 28b2e43c1..72aa25e52 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ b/tests/integration/tool_runtime/test_mcp.py @@ -40,9 +40,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server): provider_data = { "mcp_headers": { - uri: [ - f"Authorization: Bearer {AUTH_TOKEN}", - ], + uri: { + "Authorization": f"Bearer {AUTH_TOKEN}", + }, }, } auth_headers = {