mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
fix: make the fallback logic to SSE handle MCP errors properly
Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
e027a526c9
commit
ffbea13e07
1 changed file with 24 additions and 36 deletions
|
@ -5,11 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession, McpError
|
||||||
from mcp import types as mcp_types
|
from mcp import types as mcp_types
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
@ -28,7 +28,24 @@ logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||||
|
try:
|
||||||
|
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
await session.initialize()
|
||||||
|
yield session
|
||||||
|
return
|
||||||
|
except* httpx.HTTPStatusError as eg:
|
||||||
|
for exc in eg.exceptions:
|
||||||
|
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||||
|
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
||||||
|
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
||||||
|
err = cast(httpx.HTTPStatusError, exc)
|
||||||
|
if err.response.status_code == 401:
|
||||||
|
raise AuthenticationRequiredError(exc) from exc
|
||||||
|
except* McpError:
|
||||||
|
logger.debug("failed to connect via streamable http, falling back to sse")
|
||||||
|
# fallback to sse
|
||||||
try:
|
try:
|
||||||
async with sse_client(endpoint, headers=headers) as streams:
|
async with sse_client(endpoint, headers=headers) as streams:
|
||||||
async with ClientSession(*streams) as session:
|
async with ClientSession(*streams) as session:
|
||||||
|
@ -45,40 +62,10 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
|
||||||
try:
|
|
||||||
async with streamablehttp_client(endpoint, headers=headers) as (
|
|
||||||
read_stream,
|
|
||||||
write_stream,
|
|
||||||
_,
|
|
||||||
):
|
|
||||||
async with ClientSession(read_stream, write_stream) as session:
|
|
||||||
await session.initialize()
|
|
||||||
yield session
|
|
||||||
except* httpx.HTTPStatusError as eg:
|
|
||||||
for exc in eg.exceptions:
|
|
||||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
|
||||||
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
|
||||||
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
|
||||||
err = cast(httpx.HTTPStatusError, exc)
|
|
||||||
if err.response.status_code == 401:
|
|
||||||
raise AuthenticationRequiredError(exc) from exc
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> _AsyncGeneratorContextManager[ClientSession, Any]:
|
|
||||||
try:
|
|
||||||
return streamable_http_client_wrapper(endpoint, headers)
|
|
||||||
except AuthenticationRequiredError as e:
|
|
||||||
raise e # since this was an authentication error, we want to surface that instead of trying with SSE
|
|
||||||
except Exception:
|
|
||||||
return sse_client_wrapper(endpoint, headers)
|
|
||||||
|
|
||||||
|
|
||||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||||
tools = []
|
tools = []
|
||||||
async with get_client_wrapper(endpoint, headers) as session:
|
async with client_wrapper(endpoint, headers) as session:
|
||||||
|
logger.info("listing mcp tools...")
|
||||||
tools_result = await session.list_tools()
|
tools_result = await session.list_tools()
|
||||||
for tool in tools_result.tools:
|
for tool in tools_result.tools:
|
||||||
parameters = []
|
parameters = []
|
||||||
|
@ -106,7 +93,8 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
||||||
async def invoke_mcp_tool(
|
async def invoke_mcp_tool(
|
||||||
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
||||||
) -> ToolInvocationResult:
|
) -> ToolInvocationResult:
|
||||||
async with get_client_wrapper(endpoint, headers) as session:
|
async with client_wrapper(endpoint, headers) as session:
|
||||||
|
logger.info("invoking mcp tool")
|
||||||
result = await session.call_tool(tool_name, kwargs)
|
result = await session.call_tool(tool_name, kwargs)
|
||||||
|
|
||||||
content: list[InterleavedContentItem] = []
|
content: list[InterleavedContentItem] = []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue