mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
feat: use ttl dict to keep track of which mcp endpoint to use
Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
8c8d558a5c
commit
873bf8d95a
1 changed files with 45 additions and 30 deletions
|
@ -23,42 +23,57 @@ from llama_stack.apis.tools import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
|
||||||
|
|
||||||
logger = get_logger(__name__, category="tools")
|
logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||||
|
|
||||||
|
PROTOCOL_STREAMABLEHTTP = "streamable_http"
|
||||||
|
PROTOCOL_SSE = "sse"
|
||||||
|
PROTOCOL_UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||||
try:
|
# we use a ttl'd dict to cache the happy path protocol for each endpoint
|
||||||
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
|
# but, we always fall back to trying the other protocol if we cannot initialize the session
|
||||||
async with ClientSession(read_stream, write_stream) as session:
|
connection_strategies = [PROTOCOL_STREAMABLEHTTP, PROTOCOL_SSE]
|
||||||
await session.initialize()
|
mcp_protocol = protocol_cache.get(endpoint, default=PROTOCOL_UNKNOWN)
|
||||||
yield session
|
if mcp_protocol == PROTOCOL_SSE:
|
||||||
return
|
connection_strategies = [PROTOCOL_SSE, PROTOCOL_STREAMABLEHTTP]
|
||||||
except* httpx.HTTPStatusError as eg:
|
|
||||||
for exc in eg.exceptions:
|
for i, strategy in enumerate(connection_strategies):
|
||||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
try:
|
||||||
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
if strategy == PROTOCOL_STREAMABLEHTTP:
|
||||||
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
|
||||||
err = cast(httpx.HTTPStatusError, exc)
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
if err.response.status_code == 401:
|
await session.initialize()
|
||||||
raise AuthenticationRequiredError(exc) from exc
|
protocol_cache[endpoint] = PROTOCOL_STREAMABLEHTTP
|
||||||
except* McpError:
|
yield session
|
||||||
logger.warning("failed to connect via streamable http, falling back to sse")
|
return
|
||||||
try:
|
elif strategy == PROTOCOL_SSE:
|
||||||
async with sse_client(endpoint, headers=headers) as streams:
|
async with sse_client(endpoint, headers=headers) as streams:
|
||||||
async with ClientSession(*streams) as session:
|
async with ClientSession(*streams) as session:
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
yield session
|
protocol_cache[endpoint] = PROTOCOL_SSE
|
||||||
except* httpx.HTTPStatusError as eg:
|
yield session
|
||||||
for exc in eg.exceptions:
|
return
|
||||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
except* httpx.HTTPStatusError as eg:
|
||||||
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
for exc in eg.exceptions:
|
||||||
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||||
err = cast(httpx.HTTPStatusError, exc)
|
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
||||||
if err.response.status_code == 401:
|
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
||||||
raise AuthenticationRequiredError(exc) from exc
|
err = cast(httpx.HTTPStatusError, exc)
|
||||||
raise
|
if err.response.status_code == 401:
|
||||||
|
raise AuthenticationRequiredError(exc) from exc
|
||||||
|
if i == len(connection_strategies) - 1:
|
||||||
|
raise
|
||||||
|
except* McpError:
|
||||||
|
if i < len(connection_strategies) - 1:
|
||||||
|
logger.warning(f"failed to connect via {strategy}, falling back to {connection_strategies[i + 1]}")
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue