feat: made mcp tools work with streamable http and sse

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-06-30 13:16:33 -04:00
parent c9a49a80e8
commit 5098d3f462
No known key found for this signature in database
GPG key ID: B67F01AEB13FE187

View file

@ -4,13 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from collections.abc import AsyncGenerator
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from pathlib import PurePosixPath
from typing import Any, cast
from urllib.parse import unquote, urlparse
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
@ -26,7 +30,7 @@ logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
@ -43,9 +47,46 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
raise
@asynccontextmanager
async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try:
async with streamablehttp_client(endpoint, headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
raise
def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> _AsyncGeneratorContextManager[ClientSession, Any]:
path = PurePosixPath(unquote(urlparse(endpoint).path))
if path.parts[-1] != "sse":
try:
return streamable_http_client_wrapper(endpoint, headers)
except AuthenticationRequiredError as e:
raise e # since this was an authentication error, we want to surface that instead of trying with SSE
except Exception:
return sse_client_wrapper(endpoint, headers)
else:
# most SSE MCP servers are served at endpoints ending in /sse, so default to the SSE client wrapper
# if the endpoint is explicitly an SSE endpoint
return sse_client_wrapper(endpoint, headers)
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:
async with get_client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
@ -73,7 +114,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session:
async with get_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []