From ffbea13e07e3427173fa29d2547f1b0e38c508a2 Mon Sep 17 00:00:00 2001 From: Calum Murray Date: Wed, 2 Jul 2025 15:55:49 -0400 Subject: [PATCH] fix: fallback logic to sse handles mcp errors properly Signed-off-by: Calum Murray --- llama_stack/providers/utils/tools/mcp.py | 60 ++++++++++-------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index f32743225..64b7da3af 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -5,11 +5,11 @@ # the root directory of this source tree. from collections.abc import AsyncGenerator -from contextlib import _AsyncGeneratorContextManager, asynccontextmanager +from contextlib import asynccontextmanager from typing import Any, cast import httpx -from mcp import ClientSession +from mcp import ClientSession, McpError from mcp import types as mcp_types from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client @@ -28,7 +28,24 @@ logger = get_logger(__name__, category="tools") @asynccontextmanager -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: +async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: + try: + async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + yield session + return + except* httpx.HTTPStatusError as eg: + for exc in eg.exceptions: + # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, + # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because + # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. + err = cast(httpx.HTTPStatusError, exc) + if err.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + except* McpError: + logger.debug("failed to connect via streamable http, falling back to sse") + # fallback to sse try: async with sse_client(endpoint, headers=headers) as streams: async with ClientSession(*streams) as session: @@ -45,40 +62,10 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen raise -@asynccontextmanager -async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: - try: - async with streamablehttp_client(endpoint, headers=headers) as ( - read_stream, - write_stream, - _, - ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - yield session - except* httpx.HTTPStatusError as eg: - for exc in eg.exceptions: - # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, - # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because - # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. - err = cast(httpx.HTTPStatusError, exc) - if err.response.status_code == 401: - raise AuthenticationRequiredError(exc) from exc - raise - - -def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> _AsyncGeneratorContextManager[ClientSession, Any]: - try: - return streamable_http_client_wrapper(endpoint, headers) - except AuthenticationRequiredError as e: - raise e # since this was an authentication error, we want to surface that instead of trying with SSE - except Exception: - return sse_client_wrapper(endpoint, headers) - - async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: tools = [] - async with get_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: + logger.info("listing mcp tools...") tools_result = await session.list_tools() for tool in tools_result.tools: parameters = [] @@ -106,7 +93,8 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs async def invoke_mcp_tool( endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] ) -> ToolInvocationResult: - async with get_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: + logger.info("invoking mcp tool") result = await session.call_tool(tool_name, kwargs) content: list[InterleavedContentItem] = []