feat: made mcp tools work with streamable http and sse

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-06-30 13:16:33 -04:00
parent c9a49a80e8
commit 5098d3f462
No known key found for this signature in database
GPG key ID: B67F01AEB13FE187

View file

@ -4,13 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from contextlib import asynccontextmanager from collections.abc import AsyncGenerator
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from pathlib import PurePosixPath
from typing import Any, cast from typing import Any, cast
from urllib.parse import unquote, urlparse
import httpx import httpx
from mcp import ClientSession from mcp import ClientSession
from mcp import types as mcp_types from mcp import types as mcp_types
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -26,7 +30,7 @@ logger = get_logger(__name__, category="tools")
@asynccontextmanager @asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try: try:
async with sse_client(endpoint, headers=headers) as streams: async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
@ -43,9 +47,46 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
raise raise
@asynccontextmanager
async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try:
async with streamablehttp_client(endpoint, headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
raise
def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> _AsyncGeneratorContextManager[ClientSession, Any]:
path = PurePosixPath(unquote(urlparse(endpoint).path))
if path.parts[-1] != "sse":
try:
return streamable_http_client_wrapper(endpoint, headers)
except AuthenticationRequiredError as e:
raise e # since this was an authentication error, we want to surface that instead of trying with SSE
except Exception:
return sse_client_wrapper(endpoint, headers)
else:
# most SSE MCP servers are served at endpoints ending in /sse, so default to the SSE client wrapper
# if the endpoint is explicitly an SSE endpoint
return sse_client_wrapper(endpoint, headers)
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = [] tools = []
async with sse_client_wrapper(endpoint, headers) as session: async with get_client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools() tools_result = await session.list_tools()
for tool in tools_result.tools: for tool in tools_result.tools:
parameters = [] parameters = []
@ -73,7 +114,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool( async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult: ) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session: async with get_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs) result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = [] content: list[InterleavedContentItem] = []