cleanup: use enum for mcp protocols

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-07-10 14:27:02 -04:00
parent 96f67146e4
commit ac0e3e133c
No known key found for this signature in database
GPG key ID: B67F01AEB13FE187

View file

@ -6,6 +6,7 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import httpx
@ -29,30 +30,32 @@ logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
PROTOCOL_STREAMABLEHTTP = "streamable_http"
PROTOCOL_SSE = "sse"
PROTOCOL_UNKNOWN = "unknown"
class MCPProtol(Enum):
UNKNOWN = 0
STREAMABLE_HTTP = 1
SSE = 2
@asynccontextmanager
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
# we use a ttl'd dict to cache the happy path protocol for each endpoint
# but, we always fall back to trying the other protocol if we cannot initialize the session
connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE]
mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN)
if mcp_protocol == PROTOCOL_SSE:
connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP]
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
if mcp_protocol == MCPProtol.SSE:
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
for i, strategy in enumerate(connection_strategies):
try:
if strategy == PROTOCOL_STREAMABLEHTTP:
if strategy == MCPProtol.STREAMABLE_HTTP:
client = streamablehttp_client
elif strategy == PROTOCOL_SSE:
elif strategy == MCPProtol.SSE:
client = sse_client
else:
# this should not happen
logger.warning(
"tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http"
"tried to establish MCP connection with UNKNOWN protocol, defaulting to try with STREAMABLE_HTTP"
)
client = streamablehttp_client
async with client(endpoint, headers=headers) as client_streams:
@ -73,7 +76,9 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
raise
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}")
logger.warning(
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
else:
raise