feat: add MCP Streamable HTTP support (#2554)

# What does this PR do?
<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->
This PR adds support for the new Streamable HTTP transport for MCP, as
well as falling back to the SSE protocol if the Streamable HTTP
connection fails.

<!-- If resolving an issue, uncomment and update the line below -->
Closes #2542 

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->

---------

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-07-24 18:04:27 -04:00 committed by GitHub
parent 632cf9eb72
commit 7cc4819e90
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 121 additions and 18 deletions

View file

@ -4,13 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import httpx
from mcp import ClientSession
from mcp import ClientSession, McpError
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
@ -21,31 +24,61 @@ from llama_stack.apis.tools import (
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
class MCPProtol(Enum):
UNKNOWN = 0
STREAMABLE_HTTP = 1
SSE = 2
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
raise
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
# we use a ttl'd dict to cache the happy path protocol for each endpoint
# but, we always fall back to trying the other protocol if we cannot initialize the session
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
if mcp_protocol == MCPProtol.SSE:
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
for i, strategy in enumerate(connection_strategies):
try:
client = streamablehttp_client
if strategy == MCPProtol.SSE:
client = sse_client
async with client(endpoint, headers=headers) as client_streams:
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
await session.initialize()
protocol_cache[endpoint] = strategy
yield session
return
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1:
raise
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
else:
raise
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from threading import RLock
from typing import Any
class TTLDict(dict):
"""
A dictionary with a ttl for each item
"""
def __init__(self, ttl_seconds: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ttl_seconds = ttl_seconds
self._expires: dict[Any, Any] = {} # expires holds when an item will expire
self._lock = RLock()
if args or kwargs:
for k, v in self.items():
self.__setitem__(k, v)
def __delitem__(self, key):
with self._lock:
del self._expires[key]
super().__delitem__(key)
def __setitem__(self, key, value):
with self._lock:
self._expires[key] = time.monotonic() + self.ttl_seconds
super().__setitem__(key, value)
def _is_expired(self, key):
if key not in self._expires:
return False
return time.monotonic() > self._expires[key]
def __getitem__(self, key):
with self._lock:
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
raise KeyError(f"{key} has expired and was removed")
return super().__getitem__(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
try:
_ = self[key]
return True
except KeyError:
return False
def __repr__(self):
with self._lock:
for key in self.keys():
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"