diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index b435eeddb..119fff932 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -46,19 +46,21 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat for i, strategy in enumerate(connection_strategies): try: if strategy == PROTOCOL_STREAMABLEHTTP: - async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - protocol_cache[endpoint] = PROTOCOL_STREAMABLEHTTP - yield session - return + client = streamablehttp_client elif strategy == PROTOCOL_SSE: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - protocol_cache[endpoint] = PROTOCOL_SSE - yield session - return + client = sse_client + else: + # this should not happen + logger.warning( + "tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http" + ) + client = streamablehttp_client + async with client(endpoint, headers=headers) as client_streams: + async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session: + await session.initialize() + protocol_cache[endpoint] = strategy + yield session + return except* httpx.HTTPStatusError as eg: for exc in eg.exceptions: # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,