feat: use ttl dict to keep track of which mcp endpoint to use

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-07-07 10:40:20 -04:00
parent 8c8d558a5c
commit 873bf8d95a
No known key found for this signature in database
GPG key ID: B67F01AEB13FE187

View file

@ -23,42 +23,57 @@ from llama_stack.apis.tools import (
) )
from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
logger = get_logger(__name__, category="tools") logger = get_logger(__name__, category="tools")
# Per-endpoint memo of the protocol that last initialized an MCP session successfully;
# client_wrapper consults it to decide which protocol to try first.
# NOTE(review): ttl_seconds=3600 presumably expires entries after one hour — confirm against TTLDict.
protocol_cache = TTLDict(ttl_seconds=3600)
# Protocol identifiers stored in protocol_cache and used to order connection attempts.
PROTOCOL_STREAMABLEHTTP = "streamable_http"
PROTOCOL_SSE = "sse"
# Default used when an endpoint has no entry in protocol_cache.
PROTOCOL_UNKNOWN = "unknown"
@asynccontextmanager @asynccontextmanager
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try: # we use a ttl'd dict to cache the happy path protocol for each endpoint
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _): # but, we always fall back to trying the other protocol if we cannot initialize the session
async with ClientSession(read_stream, write_stream) as session: connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE]
await session.initialize() mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN)
yield session if mcp_protocol == PROTOCOL_SSE:
return connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP]
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions: for i, strategy in enumerate(connection_strategies):
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, try:
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because if strategy == PROTOCOL_STREAMABLEHTTP:
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
err = cast(httpx.HTTPStatusError, exc) async with ClientSession(read_stream, write_stream) as session:
if err.response.status_code == 401: await session.initialize()
raise AuthenticationRequiredError(exc) from exc protocol_cache[endpoint] = PROTOCOL_STREAMABLEHTTP
except* McpError: yield session
logger.warning("failed to connect via streamable http, falling back to sse") return
try: elif strategy == PROTOCOL_SSE:
async with sse_client(endpoint, headers=headers) as streams: async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
await session.initialize() await session.initialize()
yield session protocol_cache[endpoint] = PROTOCOL_SSE
except* httpx.HTTPStatusError as eg: yield session
for exc in eg.exceptions: return
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, except* httpx.HTTPStatusError as eg:
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because for exc in eg.exceptions:
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
err = cast(httpx.HTTPStatusError, exc) # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
if err.response.status_code == 401: # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
raise AuthenticationRequiredError(exc) from exc err = cast(httpx.HTTPStatusError, exc)
raise if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1:
raise
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}")
else:
raise
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: