diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index e6ea16cd0..d0753933c 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from enum import Enum from typing import Any, cast import httpx @@ -29,30 +30,32 @@ logger = get_logger(__name__, category="tools") protocol_cache = TTLDict(ttl_seconds=3600) -PROTOCOL_STREAMABLEHTTP = "streamable_http" -PROTOCOL_SSE = "sse" -PROTOCOL_UNKNOWN = "unknown" + +class MCPProtol(Enum): + UNKNOWN = 0 + STREAMABLE_HTTP = 1 + SSE = 2 @asynccontextmanager async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: # we use a ttl'd dict to cache the happy path protocol for each endpoint # but, we always fall back to trying the other protocol if we cannot initialize the session - connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE] - mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN) - if mcp_protocol == PROTOCOL_SSE: - connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP] + connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE] + mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN) + if mcp_protocol == MCPProtol.SSE: + connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP] for i, strategy in enumerate(connection_strategies): try: - if strategy == PROTOCOL_STREAMABLEHTTP: + if strategy == MCPProtol.STREAMABLE_HTTP: client = streamablehttp_client - elif strategy == PROTOCOL_SSE: + elif strategy == MCPProtol.SSE: client = sse_client else: # this should not happen logger.warning( - "tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http" + "tried to establish MCP connection with UNKNOWN protocol, defaulting to try with STREAMABLE_HTTP" ) client = streamablehttp_client async with client(endpoint, headers=headers) as client_streams: @@ -73,7 +76,9 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat raise except* McpError: if i < len(connection_strategies) - 1: - logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}") + logger.warning( + f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) else: raise