From 873bf8d95ae7a1d1b0f8b693887ca4c9ab1f7765 Mon Sep 17 00:00:00 2001
From: Calum Murray
Date: Mon, 7 Jul 2025 10:40:20 -0400
Subject: [PATCH] feat: use ttl dict to keep track of which mcp endpoint to use

Signed-off-by: Calum Murray
---
 llama_stack/providers/utils/tools/mcp.py | 88 ++++++++++++++++--------
 1 file changed, 58 insertions(+), 30 deletions(-)

diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py
index 56bc7db7b..b435eeddb 100644
--- a/llama_stack/providers/utils/tools/mcp.py
+++ b/llama_stack/providers/utils/tools/mcp.py
@@ -23,42 +23,70 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.distribution.datatypes import AuthenticationRequiredError
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.tools.ttl_dict import TTLDict
 
 logger = get_logger(__name__, category="tools")
 
+protocol_cache = TTLDict(ttl_seconds=3600)
+
+PROTOCOL_STREAMABLEHTTP = "streamable_http"
+PROTOCOL_SSE = "sse"
+PROTOCOL_UNKNOWN = "unknown"
+
 
 @asynccontextmanager
 async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
-    try:
-        async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
-            async with ClientSession(read_stream, write_stream) as session:
-                await session.initialize()
-                yield session
-                return
-    except* httpx.HTTPStatusError as eg:
-        for exc in eg.exceptions:
-            # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
-            # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
-            # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
-            err = cast(httpx.HTTPStatusError, exc)
-            if err.response.status_code == 401:
-                raise AuthenticationRequiredError(exc) from exc
-    except* McpError:
-        logger.warning("failed to connect via streamable http, falling back to sse")
-        try:
-            async with sse_client(endpoint, headers=headers) as streams:
-                async with ClientSession(*streams) as session:
-                    await session.initialize()
-                    yield session
-        except* httpx.HTTPStatusError as eg:
-            for exc in eg.exceptions:
-                # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
-                # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
-                # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
-                err = cast(httpx.HTTPStatusError, exc)
-                if err.response.status_code == 401:
-                    raise AuthenticationRequiredError(exc) from exc
-            raise
+    """Yield an initialized MCP ``ClientSession`` for ``endpoint``.
+
+    The transport cached for the endpoint (streamable HTTP unless SSE is cached) is tried
+    first; the other one is tried only when the session could not be established.
+
+    Raises:
+        AuthenticationRequiredError: when the server answers with HTTP 401.
+    """
+    # we use a ttl'd dict to cache the happy path protocol for each endpoint
+    # but, we always fall back to trying the other protocol if we cannot initialize the session
+    connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE]
+    mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN)
+    if mcp_protocol == PROTOCOL_SSE:
+        connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP]
+
+    for i, strategy in enumerate(connection_strategies):
+        is_last = i == len(connection_strategies) - 1
+        # once the session has been handed to the caller, errors come from the caller's use of it;
+        # falling back then would make this generator yield twice, which asynccontextmanager rejects
+        session_yielded = False
+        try:
+            if strategy == PROTOCOL_STREAMABLEHTTP:
+                async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
+                    async with ClientSession(read_stream, write_stream) as session:
+                        await session.initialize()
+                        protocol_cache[endpoint] = PROTOCOL_STREAMABLEHTTP
+                        session_yielded = True
+                        yield session
+                        return
+            elif strategy == PROTOCOL_SSE:
+                async with sse_client(endpoint, headers=headers) as streams:
+                    async with ClientSession(*streams) as session:
+                        await session.initialize()
+                        protocol_cache[endpoint] = PROTOCOL_SSE
+                        session_yielded = True
+                        yield session
+                        return
+        except* httpx.HTTPStatusError as eg:
+            for exc in eg.exceptions:
+                # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
+                # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
+                # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
+                err = cast(httpx.HTTPStatusError, exc)
+                if err.response.status_code == 401:
+                    raise AuthenticationRequiredError(exc) from exc
+            if session_yielded or is_last:
+                raise
+        except* McpError:
+            if session_yielded or is_last:
+                raise
+            logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}")
 
 
 async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: