fix: fallback logic to sse handles mcp errors properly

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-07-02 15:55:49 -04:00
parent e027a526c9
commit ffbea13e07
No known key found for this signature in database
GPG key ID: B67F01AEB13FE187

View file

@ -5,11 +5,11 @@
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from contextlib import asynccontextmanager
from typing import Any, cast
import httpx
from mcp import ClientSession
from mcp import ClientSession, McpError
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
@ -28,7 +28,24 @@ logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try:
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
return
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
except* McpError:
logger.debug("failed to connect via streamable http, falling back to sse")
# fallback to sse
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
@ -45,40 +62,10 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen
raise
@asynccontextmanager
async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try:
async with streamablehttp_client(endpoint, headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
raise
def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> _AsyncGeneratorContextManager[ClientSession, Any]:
try:
return streamable_http_client_wrapper(endpoint, headers)
except AuthenticationRequiredError as e:
raise e # since this was an authentication error, we want to surface that instead of trying with SSE
except Exception:
return sse_client_wrapper(endpoint, headers)
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with get_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
logger.info("listing mcp tools...")
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
@ -106,7 +93,8 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with get_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
logger.info("invoking mcp tool")
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []