From caecaa8b31e3caa70d7969069d698e94630f9306 Mon Sep 17 00:00:00 2001 From: Calum Murray Date: Thu, 10 Jul 2025 14:15:29 -0400 Subject: [PATCH] cleanup: simplify connection code Signed-off-by: Calum Murray --- llama_stack/providers/utils/tools/mcp.py | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index b435eeddb..119fff932 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -46,19 +46,21 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat for i, strategy in enumerate(connection_strategies): try: if strategy == PROTOCOL_STREAMABLEHTTP: - async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - protocol_cache[endpoint] = PROTOCOL_STREAMABLEHTTP - yield session - return + client = streamablehttp_client elif strategy == PROTOCOL_SSE: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - protocol_cache[endpoint] = PROTOCOL_SSE - yield session - return + client = sse_client + else: + # this should not happen + logger.warning( + "tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http" + ) + client = streamablehttp_client + async with client(endpoint, headers=headers) as client_streams: + async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session: + await session.initialize() + protocol_cache[endpoint] = strategy + yield session + return except* httpx.HTTPStatusError as eg: for exc in eg.exceptions: # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,