mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
cleanup: use enum for mcp protocols
Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
96f67146e4
commit
ac0e3e133c
1 changed files with 16 additions and 11 deletions
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -29,30 +30,32 @@ logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
protocol_cache = TTLDict(ttl_seconds=3600)
|
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||||
|
|
||||||
PROTOCOL_STREAMABLEHTTP = "streamable_http"
|
|
||||||
PROTOCOL_SSE = "sse"
|
class MCPProtol(Enum):
|
||||||
PROTOCOL_UNKNOWN = "unknown"
|
UNKNOWN = 0
|
||||||
|
STREAMABLE_HTTP = 1
|
||||||
|
SSE = 2
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||||
# we use a ttl'd dict to cache the happy path protocol for each endpoint
|
# we use a ttl'd dict to cache the happy path protocol for each endpoint
|
||||||
# but, we always fall back to trying the other protocol if we cannot initialize the session
|
# but, we always fall back to trying the other protocol if we cannot initialize the session
|
||||||
connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE]
|
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
|
||||||
mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN)
|
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
|
||||||
if mcp_protocol == PROTOCOL_SSE:
|
if mcp_protocol == MCPProtol.SSE:
|
||||||
connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP]
|
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
|
||||||
|
|
||||||
for i, strategy in enumerate(connection_strategies):
|
for i, strategy in enumerate(connection_strategies):
|
||||||
try:
|
try:
|
||||||
if strategy == PROTOCOL_STREAMABLEHTTP:
|
if strategy == MCPProtol.STREAMABLE_HTTP:
|
||||||
client = streamablehttp_client
|
client = streamablehttp_client
|
||||||
elif strategy == PROTOCOL_SSE:
|
elif strategy == MCPProtol.SSE:
|
||||||
client = sse_client
|
client = sse_client
|
||||||
else:
|
else:
|
||||||
# this should not happen
|
# this should not happen
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http"
|
"tried to establish MCP connection with UNKNOWN protocol, defaulting to try with STREAMABLE_HTTP"
|
||||||
)
|
)
|
||||||
client = streamablehttp_client
|
client = streamablehttp_client
|
||||||
async with client(endpoint, headers=headers) as client_streams:
|
async with client(endpoint, headers=headers) as client_streams:
|
||||||
|
@ -73,7 +76,9 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
||||||
raise
|
raise
|
||||||
except* McpError:
|
except* McpError:
|
||||||
if i < len(connection_strategies) - 1:
|
if i < len(connection_strategies) - 1:
|
||||||
logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}")
|
logger.warning(
|
||||||
|
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue