fix: match mcp headers in provider data to Responses API shape (#2263)

This commit is contained in:
Ashwin Bharambe 2025-05-25 14:33:10 -07:00 committed by GitHub
parent ce33d02443
commit 9623d5d230
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 7 additions and 17 deletions

View file

@ -10,8 +10,8 @@ from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => list of headers to send
mcp_headers: dict[str, list[str]] | None = None
# mcp_endpoint => dict of headers to send
mcp_headers: dict[str, dict[str, str]] | None = None
class MCPProviderConfig(BaseModel):

View file

@ -18,7 +18,7 @@ from llama_stack.apis.tools import (
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
@ -69,5 +69,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
headers.update(convert_header_list_to_dict(values))
headers.update(values)
return headers

View file

@ -51,16 +51,6 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
raise
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
headers = {}
for header in header_list:
parts = header.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session: