cleanup: use enum for mcp protocols

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-07-10 14:27:02 -04:00
parent 96f67146e4
commit ac0e3e133c
No known key found for this signature in database
GPG key ID: B67F01AEB13FE187

View file

@ -6,6 +6,7 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast from typing import Any, cast
import httpx import httpx
@ -29,30 +30,32 @@ logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600) protocol_cache = TTLDict(ttl_seconds=3600)
PROTOCOL_STREAMABLEHTTP = "streamable_http"
PROTOCOL_SSE = "sse" class MCPProtol(Enum):
PROTOCOL_UNKNOWN = "unknown" UNKNOWN = 0
STREAMABLE_HTTP = 1
SSE = 2
@asynccontextmanager @asynccontextmanager
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
# we use a ttl'd dict to cache the happy path protocol for each endpoint # we use a ttl'd dict to cache the happy path protocol for each endpoint
# but, we always fall back to trying the other protocol if we cannot initialize the session # but, we always fall back to trying the other protocol if we cannot initialize the session
connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE] connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN) mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
if mcp_protocol == PROTOCOL_SSE: if mcp_protocol == MCPProtol.SSE:
connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP] connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
for i, strategy in enumerate(connection_strategies): for i, strategy in enumerate(connection_strategies):
try: try:
if strategy == PROTOCOL_STREAMABLEHTTP: if strategy == MCPProtol.STREAMABLE_HTTP:
client = streamablehttp_client client = streamablehttp_client
elif strategy == PROTOCOL_SSE: elif strategy == MCPProtol.SSE:
client = sse_client client = sse_client
else: else:
# this should not happen # this should not happen
logger.warning( logger.warning(
"tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http" "tried to establish MCP connection with UNKNOWN protocol, defaulting to try with STREAMABLE_HTTP"
) )
client = streamablehttp_client client = streamablehttp_client
async with client(endpoint, headers=headers) as client_streams: async with client(endpoint, headers=headers) as client_streams:
@ -73,7 +76,9 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
raise raise
except* McpError: except* McpError:
if i < len(connection_strategies) - 1: if i < len(connection_strategies) - 1:
logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}") logger.warning(
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
else: else:
raise raise