mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
feat: made mcp tools work with streamable http and sse
Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
c9a49a80e8
commit
5098d3f462
1 changed files with 45 additions and 4 deletions
|
@ -4,13 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any, cast
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp import types as mcp_types
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -26,7 +30,7 @@ logger = get_logger(__name__, category="tools")
|
|||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||
try:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
|
@ -43,9 +47,46 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
|||
raise
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||
try:
|
||||
async with streamablehttp_client(endpoint, headers=headers) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
except* httpx.HTTPStatusError as eg:
|
||||
for exc in eg.exceptions:
|
||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
||||
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
||||
err = cast(httpx.HTTPStatusError, exc)
|
||||
if err.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
raise
|
||||
|
||||
|
||||
def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> _AsyncGeneratorContextManager[ClientSession, Any]:
|
||||
path = PurePosixPath(unquote(urlparse(endpoint).path))
|
||||
if path.parts[-1] != "sse":
|
||||
try:
|
||||
return streamable_http_client_wrapper(endpoint, headers)
|
||||
except AuthenticationRequiredError as e:
|
||||
raise e # since this was an authentication error, we want to surface that instead of trying with SSE
|
||||
except Exception:
|
||||
return sse_client_wrapper(endpoint, headers)
|
||||
else:
|
||||
# most SSE MCP servers are served at endpoints ending in /sse, so default to the SSE client wrapper
|
||||
# if the endpoint is explicitly an SSE endpoint
|
||||
return sse_client_wrapper(endpoint, headers)
|
||||
|
||||
|
||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||
tools = []
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
async with get_client_wrapper(endpoint, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
|
@ -73,7 +114,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
|||
async def invoke_mcp_tool(
|
||||
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
async with get_client_wrapper(endpoint, headers) as session:
|
||||
result = await session.call_tool(tool_name, kwargs)
|
||||
|
||||
content: list[InterleavedContentItem] = []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue