mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
feat: add MCP Streamable HTTP support (#2554)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> This PR adds support for the new Streamable HTTP transport for MCP, as well as falling back to the SSE protocol if the Streamable HTTP connection fails. <!-- If resolving an issue, uncomment and update the line below --> Closes #2542 ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> --------- Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
632cf9eb72
commit
7cc4819e90
2 changed files with 121 additions and 18 deletions
|
@ -4,13 +4,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession, McpError
|
||||||
from mcp import types as mcp_types
|
from mcp import types as mcp_types
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
|
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -21,31 +24,61 @@ from llama_stack.apis.tools import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
|
||||||
|
|
||||||
logger = get_logger(__name__, category="tools")
|
logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPProtol(Enum):
|
||||||
|
UNKNOWN = 0
|
||||||
|
STREAMABLE_HTTP = 1
|
||||||
|
SSE = 2
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||||
try:
|
# we use a ttl'd dict to cache the happy path protocol for each endpoint
|
||||||
async with sse_client(endpoint, headers=headers) as streams:
|
# but, we always fall back to trying the other protocol if we cannot initialize the session
|
||||||
async with ClientSession(*streams) as session:
|
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
|
||||||
await session.initialize()
|
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
|
||||||
yield session
|
if mcp_protocol == MCPProtol.SSE:
|
||||||
except* httpx.HTTPStatusError as eg:
|
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
|
||||||
for exc in eg.exceptions:
|
|
||||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
for i, strategy in enumerate(connection_strategies):
|
||||||
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
try:
|
||||||
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
client = streamablehttp_client
|
||||||
err = cast(httpx.HTTPStatusError, exc)
|
if strategy == MCPProtol.SSE:
|
||||||
if err.response.status_code == 401:
|
client = sse_client
|
||||||
raise AuthenticationRequiredError(exc) from exc
|
async with client(endpoint, headers=headers) as client_streams:
|
||||||
raise
|
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
|
||||||
|
await session.initialize()
|
||||||
|
protocol_cache[endpoint] = strategy
|
||||||
|
yield session
|
||||||
|
return
|
||||||
|
except* httpx.HTTPStatusError as eg:
|
||||||
|
for exc in eg.exceptions:
|
||||||
|
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||||
|
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
||||||
|
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
||||||
|
err = cast(httpx.HTTPStatusError, exc)
|
||||||
|
if err.response.status_code == 401:
|
||||||
|
raise AuthenticationRequiredError(exc) from exc
|
||||||
|
if i == len(connection_strategies) - 1:
|
||||||
|
raise
|
||||||
|
except* McpError:
|
||||||
|
if i < len(connection_strategies) - 1:
|
||||||
|
logger.warning(
|
||||||
|
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||||
tools = []
|
tools = []
|
||||||
async with sse_client_wrapper(endpoint, headers) as session:
|
async with client_wrapper(endpoint, headers) as session:
|
||||||
tools_result = await session.list_tools()
|
tools_result = await session.list_tools()
|
||||||
for tool in tools_result.tools:
|
for tool in tools_result.tools:
|
||||||
parameters = []
|
parameters = []
|
||||||
|
@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
||||||
async def invoke_mcp_tool(
|
async def invoke_mcp_tool(
|
||||||
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
||||||
) -> ToolInvocationResult:
|
) -> ToolInvocationResult:
|
||||||
async with sse_client_wrapper(endpoint, headers) as session:
|
async with client_wrapper(endpoint, headers) as session:
|
||||||
result = await session.call_tool(tool_name, kwargs)
|
result = await session.call_tool(tool_name, kwargs)
|
||||||
|
|
||||||
content: list[InterleavedContentItem] = []
|
content: list[InterleavedContentItem] = []
|
||||||
|
|
70
llama_stack/providers/utils/tools/ttl_dict.py
Normal file
70
llama_stack/providers/utils/tools/ttl_dict.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import time
|
||||||
|
from threading import RLock
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class TTLDict(dict):
|
||||||
|
"""
|
||||||
|
A dictionary with a ttl for each item
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ttl_seconds: float, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.ttl_seconds = ttl_seconds
|
||||||
|
self._expires: dict[Any, Any] = {} # expires holds when an item will expire
|
||||||
|
self._lock = RLock()
|
||||||
|
|
||||||
|
if args or kwargs:
|
||||||
|
for k, v in self.items():
|
||||||
|
self.__setitem__(k, v)
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
with self._lock:
|
||||||
|
del self._expires[key]
|
||||||
|
super().__delitem__(key)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
with self._lock:
|
||||||
|
self._expires[key] = time.monotonic() + self.ttl_seconds
|
||||||
|
super().__setitem__(key, value)
|
||||||
|
|
||||||
|
def _is_expired(self, key):
|
||||||
|
if key not in self._expires:
|
||||||
|
return False
|
||||||
|
return time.monotonic() > self._expires[key]
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
with self._lock:
|
||||||
|
if self._is_expired(key):
|
||||||
|
del self._expires[key]
|
||||||
|
super().__delitem__(key)
|
||||||
|
raise KeyError(f"{key} has expired and was removed")
|
||||||
|
|
||||||
|
return super().__getitem__(key)
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
try:
|
||||||
|
return self[key]
|
||||||
|
except KeyError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
try:
|
||||||
|
_ = self[key]
|
||||||
|
return True
|
||||||
|
except KeyError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
with self._lock:
|
||||||
|
for key in self.keys():
|
||||||
|
if self._is_expired(key):
|
||||||
|
del self._expires[key]
|
||||||
|
super().__delitem__(key)
|
||||||
|
return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"
|
Loading…
Add table
Add a link
Reference in a new issue