From 7cc4819e90b172ff78903fbb9eae5bb506728c56 Mon Sep 17 00:00:00 2001 From: Calum Murray Date: Thu, 24 Jul 2025 18:04:27 -0400 Subject: [PATCH] feat: add MCP Streamable HTTP support (#2554) # What does this PR do? This PR adds support for the new Streamable HTTP transport for MCP, as well as falling back to the SSE protocol if the Streamable HTTP connection fails. Closes #2542 ## Test Plan --------- Signed-off-by: Calum Murray --- llama_stack/providers/utils/tools/mcp.py | 69 +++++++++++++----- llama_stack/providers/utils/tools/ttl_dict.py | 70 +++++++++++++++++++ 2 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 llama_stack/providers/utils/tools/ttl_dict.py diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index fbf992c82..76593a4b8 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -4,13 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from enum import Enum from typing import Any, cast import httpx -from mcp import ClientSession +from mcp import ClientSession, McpError from mcp import types as mcp_types from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem from llama_stack.apis.tools import ( @@ -21,31 +24,61 @@ from llama_stack.apis.tools import ( ) from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.log import get_logger +from llama_stack.providers.utils.tools.ttl_dict import TTLDict logger = get_logger(__name__, category="tools") +protocol_cache = TTLDict(ttl_seconds=3600) + + +class MCPProtol(Enum): + UNKNOWN = 0 + STREAMABLE_HTTP = 1 + SSE = 2 + @asynccontextmanager -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): - try: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - except* httpx.HTTPStatusError as eg: - for exc in eg.exceptions: - # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, - # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because - # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. 
- err = cast(httpx.HTTPStatusError, exc) - if err.response.status_code == 401: - raise AuthenticationRequiredError(exc) from exc - raise +async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: + # we use a ttl'd dict to cache the happy path protocol for each endpoint + # but, we always fall back to trying the other protocol if we cannot initialize the session + connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE] + mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN) + if mcp_protocol == MCPProtol.SSE: + connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP] + + for i, strategy in enumerate(connection_strategies): + try: + client = streamablehttp_client + if strategy == MCPProtol.SSE: + client = sse_client + async with client(endpoint, headers=headers) as client_streams: + async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session: + await session.initialize() + protocol_cache[endpoint] = strategy + yield session + return + except* httpx.HTTPStatusError as eg: + for exc in eg.exceptions: + # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, + # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because + # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. 
+ err = cast(httpx.HTTPStatusError, exc) + if err.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + if i == len(connection_strategies) - 1: + raise + except* McpError: + if i < len(connection_strategies) - 1: + logger.warning( + f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + else: + raise async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: tools = [] - async with sse_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: tools_result = await session.list_tools() for tool in tools_result.tools: parameters = [] @@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs async def invoke_mcp_tool( endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] ) -> ToolInvocationResult: - async with sse_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: result = await session.call_tool(tool_name, kwargs) content: list[InterleavedContentItem] = [] diff --git a/llama_stack/providers/utils/tools/ttl_dict.py b/llama_stack/providers/utils/tools/ttl_dict.py new file mode 100644 index 000000000..2a2605a52 --- /dev/null +++ b/llama_stack/providers/utils/tools/ttl_dict.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
import time
from threading import RLock
from typing import Any


class TTLDict(dict):
    """
    A dictionary with a ttl for each item.

    Every item stored through ``__setitem__`` (or passed to the constructor)
    expires ``ttl_seconds`` after it was last set, measured with
    ``time.monotonic()``. Expiry is lazy: an expired item is removed the next
    time it is looked up via ``d[key]``, ``d.get(key)``, ``key in d`` or when
    ``repr(d)`` is computed.

    Only the methods overridden in this class are TTL-aware. Inherited
    ``dict`` methods such as ``len()``, iteration, ``update()`` and ``pop()``
    neither evict expired items nor record an expiry time.
    """

    def __init__(self, ttl_seconds: float, *args, **kwargs):
        """
        Create the dictionary.

        :param ttl_seconds: lifetime, in seconds, of every item after it is set
        :param args: positional arguments forwarded to ``dict``
        :param kwargs: keyword arguments forwarded to ``dict``
        """
        super().__init__(*args, **kwargs)
        self.ttl_seconds = ttl_seconds
        self._lock = RLock()
        # _expires maps each key to the monotonic timestamp at which it expires.
        # Items supplied to the constructor start their ttl now.
        self._expires: dict[Any, float] = dict.fromkeys(self, time.monotonic() + ttl_seconds)

    def _evict(self, key) -> None:
        """Drop ``key`` and its expiry record; a no-op for whichever part is already gone."""
        self._expires.pop(key, None)
        super().pop(key, None)

    def __delitem__(self, key):
        """Delete ``key``; raises ``KeyError`` if it is not in the dictionary."""
        with self._lock:
            # The expiry record may legitimately be missing (e.g. the item was
            # added through dict.update(), which bypasses __setitem__), so its
            # absence must not mask the real deletion below.
            self._expires.pop(key, None)
            super().__delitem__(key)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` and (re)start its ttl."""
        with self._lock:
            self._expires[key] = time.monotonic() + self.ttl_seconds
            super().__setitem__(key, value)

    def _is_expired(self, key):
        """Return True if ``key`` has an expiry record that lies in the past."""
        if key not in self._expires:
            return False
        return time.monotonic() > self._expires[key]

    def __getitem__(self, key):
        """Return the value for ``key``; raises ``KeyError`` if missing or expired."""
        with self._lock:
            if self._is_expired(key):
                self._evict(key)
                raise KeyError(f"{key} has expired and was removed")

            return super().__getitem__(key)

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` if it is missing or expired."""
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        """Return True if ``key`` is present and has not expired."""
        try:
            _ = self[key]
            return True
        except KeyError:
            return False

    def __repr__(self):
        """Return ``TTLDict(<ttl>, {...})`` after evicting every expired item."""
        with self._lock:
            # Iterate over a snapshot of the keys: deleting from the dict while
            # iterating its live key view raises
            # "RuntimeError: dictionary changed size during iteration".
            for key in list(self.keys()):
                if self._is_expired(key):
                    self._evict(key)
            return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"