mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
cleanup: use enum for mcp protocols
Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
96f67146e4
commit
ac0e3e133c
1 changed files with 16 additions and 11 deletions
|
@ -6,6 +6,7 @@
|
|||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
@ -29,30 +30,32 @@ logger = get_logger(__name__, category="tools")
|
|||
|
||||
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||
|
||||
PROTOCOL_STREAMABLEHTTP = "streamable_http"
|
||||
PROTOCOL_SSE = "sse"
|
||||
PROTOCOL_UNKNOWN = "unknown"
|
||||
|
||||
class MCPProtol(Enum):
|
||||
UNKNOWN = 0
|
||||
STREAMABLE_HTTP = 1
|
||||
SSE = 2
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||
# we use a ttl'd dict to cache the happy path protocol for each endpoint
|
||||
# but, we always fall back to trying the other protocol if we cannot initialize the session
|
||||
connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE]
|
||||
mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN)
|
||||
if mcp_protocol == PROTOCOL_SSE:
|
||||
connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP]
|
||||
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
|
||||
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
|
||||
if mcp_protocol == MCPProtol.SSE:
|
||||
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
|
||||
|
||||
for i, strategy in enumerate(connection_strategies):
|
||||
try:
|
||||
if strategy == PROTOCOL_STREAMABLEHTTP:
|
||||
if strategy == MCPProtol.STREAMABLE_HTTP:
|
||||
client = streamablehttp_client
|
||||
elif strategy == PROTOCOL_SSE:
|
||||
elif strategy == MCPProtol.SSE:
|
||||
client = sse_client
|
||||
else:
|
||||
# this should not happen
|
||||
logger.warning(
|
||||
"tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http"
|
||||
"tried to establish MCP connection with UNKNOWN protocol, defaulting to try with STREAMABLE_HTTP"
|
||||
)
|
||||
client = streamablehttp_client
|
||||
async with client(endpoint, headers=headers) as client_streams:
|
||||
|
@ -73,7 +76,9 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
|||
raise
|
||||
except* McpError:
|
||||
if i < len(connection_strategies) - 1:
|
||||
logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}")
|
||||
logger.warning(
|
||||
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue