diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index fbf992c82..76593a4b8 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -4,13 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from enum import Enum from typing import Any, cast import httpx -from mcp import ClientSession +from mcp import ClientSession, McpError from mcp import types as mcp_types from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem from llama_stack.apis.tools import ( @@ -21,31 +24,61 @@ from llama_stack.apis.tools import ( ) from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.log import get_logger +from llama_stack.providers.utils.tools.ttl_dict import TTLDict logger = get_logger(__name__, category="tools") +protocol_cache = TTLDict(ttl_seconds=3600) + + +class MCPProtol(Enum): + UNKNOWN = 0 + STREAMABLE_HTTP = 1 + SSE = 2 + @asynccontextmanager -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): - try: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - except* httpx.HTTPStatusError as eg: - for exc in eg.exceptions: - # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, - # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because - # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. - err = cast(httpx.HTTPStatusError, exc) - if err.response.status_code == 401: - raise AuthenticationRequiredError(exc) from exc - raise +async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: + # we use a ttl'd dict to cache the happy path protocol for each endpoint + # but, we always fall back to trying the other protocol if we cannot initialize the session + connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE] + mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN) + if mcp_protocol == MCPProtol.SSE: + connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP] + + for i, strategy in enumerate(connection_strategies): + try: + client = streamablehttp_client + if strategy == MCPProtol.SSE: + client = sse_client + async with client(endpoint, headers=headers) as client_streams: + async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session: + await session.initialize() + protocol_cache[endpoint] = strategy + yield session + return + except* httpx.HTTPStatusError as eg: + for exc in eg.exceptions: + # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, + # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because + # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. + err = cast(httpx.HTTPStatusError, exc) + if err.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + if i == len(connection_strategies) - 1: + raise + except* McpError: + if i < len(connection_strategies) - 1: + logger.warning( + f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + else: + raise async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: tools = [] - async with sse_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: tools_result = await session.list_tools() for tool in tools_result.tools: parameters = [] @@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs async def invoke_mcp_tool( endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] ) -> ToolInvocationResult: - async with sse_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: result = await session.call_tool(tool_name, kwargs) content: list[InterleavedContentItem] = [] diff --git a/llama_stack/providers/utils/tools/ttl_dict.py b/llama_stack/providers/utils/tools/ttl_dict.py new file mode 100644 index 000000000..2a2605a52 --- /dev/null +++ b/llama_stack/providers/utils/tools/ttl_dict.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from threading import RLock +from typing import Any + + +class TTLDict(dict): + """ + A dictionary with a ttl for each item + """ + + def __init__(self, ttl_seconds: float, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ttl_seconds = ttl_seconds + self._expires: dict[Any, Any] = {} # expires holds when an item will expire + self._lock = RLock() + + if args or kwargs: + for k, v in self.items(): + self.__setitem__(k, v) + + def __delitem__(self, key): + with self._lock: + del self._expires[key] + super().__delitem__(key) + + def __setitem__(self, key, value): + with self._lock: + self._expires[key] = time.monotonic() + self.ttl_seconds + super().__setitem__(key, value) + + def _is_expired(self, key): + if key not in self._expires: + return False + return time.monotonic() > self._expires[key] + + def __getitem__(self, key): + with self._lock: + if self._is_expired(key): + del self._expires[key] + super().__delitem__(key) + raise KeyError(f"{key} has expired and was removed") + + return super().__getitem__(key) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + _ = self[key] + return True + except KeyError: + return False + + def __repr__(self): + with self._lock: + for key in self.keys(): + if self._is_expired(key): + del self._expires[key] + super().__delitem__(key) + return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"