From 5098d3f4624d491d4c2b8ea6c422290aabcf4038 Mon Sep 17 00:00:00 2001
From: Calum Murray
Date: Mon, 30 Jun 2025 13:16:33 -0400
Subject: [PATCH] feat: made mcp tools work with streamable http and sse

Signed-off-by: Calum Murray
---
Review note: get_client_wrapper was reworked after this patch was generated, so
the post-image blob id on the index line below is stale.

 llama_stack/providers/utils/tools/mcp.py | 72 ++++++++++++++++++++++++++++++++++--
 1 file changed, 69 insertions(+), 3 deletions(-)

diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py
index fbf992c82..9ca457ec0 100644
--- a/llama_stack/providers/utils/tools/mcp.py
+++ b/llama_stack/providers/utils/tools/mcp.py
@@ -4,13 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
+from pathlib import PurePosixPath
 from typing import Any, cast
+from urllib.parse import unquote, urlparse
 
 import httpx
 from mcp import ClientSession
 from mcp import types as mcp_types
 from mcp.client.sse import sse_client
+from mcp.client.streamable_http import streamablehttp_client
 
 from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
 from llama_stack.apis.tools import (
@@ -26,7 +30,7 @@ logger = get_logger(__name__, category="tools")
 
 
 @asynccontextmanager
-async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
+async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
     try:
         async with sse_client(endpoint, headers=headers) as streams:
             async with ClientSession(*streams) as session:
@@ -43,9 +47,71 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
         raise
 
 
+@asynccontextmanager
+async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
+    """Yield an initialized MCP ClientSession connected to `endpoint` over streamable HTTP.
+
+    A 401 response is reported as AuthenticationRequiredError; other HTTP status errors propagate.
+    """
+    try:
+        async with streamablehttp_client(endpoint, headers=headers) as (
+            read_stream,
+            write_stream,
+            _,
+        ):
+            async with ClientSession(read_stream, write_stream) as session:
+                await session.initialize()
+                yield session
+    except* httpx.HTTPStatusError as eg:
+        for exc in eg.exceptions:
+            # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
+            # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
+            # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
+            err = cast(httpx.HTTPStatusError, exc)
+            if err.response.status_code == 401:
+                raise AuthenticationRequiredError(exc) from exc
+        raise
+
+
+@asynccontextmanager
+async def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
+    """Yield an initialized MCP session, preferring streamable HTTP and falling back to SSE.
+
+    Endpoints whose last path segment is `sse` are opened with the SSE client directly, since most SSE
+    MCP servers are served at endpoints ending in /sse. Every other endpoint is tried with streamable
+    HTTP first and retried over SSE only when the streamable HTTP session cannot be established.
+    Authentication failures and errors raised while the session is in use are never retried.
+    """
+    # `.name` is "" for a URL with no path, whereas indexing `.parts[-1]` would raise IndexError
+    if PurePosixPath(unquote(urlparse(endpoint).path)).name == "sse":
+        async with sse_client_wrapper(endpoint, headers) as session:
+            yield session
+        return
+
+    session_started = False
+    try:
+        async with streamable_http_client_wrapper(endpoint, headers) as session:
+            session_started = True
+            yield session
+        return
+    except Exception as exc:
+        # an error raised inside an `except*` handler may arrive wrapped in an ExceptionGroup,
+        # so look for AuthenticationRequiredError both bare and nested
+        auth_required = isinstance(exc, AuthenticationRequiredError) or (
+            isinstance(exc, ExceptionGroup) and exc.subgroup(AuthenticationRequiredError) is not None
+        )
+        # the context manager may only yield once, so errors from the caller's body must not trigger a retry
+        if session_started or auth_required:
+            raise
+        logger.warning("streamable HTTP connection to %s failed (%r), retrying with SSE", endpoint, exc)
+
+    async with sse_client_wrapper(endpoint, headers) as session:
+        yield session
+
+
 async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
     tools = []
-    async with sse_client_wrapper(endpoint, headers) as session:
+    async with get_client_wrapper(endpoint, headers) as session:
         tools_result = await session.list_tools()
         for tool in tools_result.tools:
             parameters = []
@@ -73,7 +139,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
 async def invoke_mcp_tool(
     endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
 ) -> ToolInvocationResult:
-    async with sse_client_wrapper(endpoint, headers) as session:
+    async with get_client_wrapper(endpoint, headers) as session:
         result = await session.call_tool(tool_name, kwargs)
 
         content: list[InterleavedContentItem] = []