fix: fallback logic to sse handles mcp errors properly

Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
Calum Murray 2025-07-02 15:55:49 -04:00
parent e027a526c9
commit ffbea13e07
No known key found for this signature in database
GPG key ID: B67F01AEB13FE187

View file

@ -5,11 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, cast from typing import Any, cast
import httpx import httpx
from mcp import ClientSession from mcp import ClientSession, McpError
from mcp import types as mcp_types from mcp import types as mcp_types
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
@ -28,7 +28,24 @@ logger = get_logger(__name__, category="tools")
@asynccontextmanager @asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try:
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
return
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
except* McpError:
logger.debug("failed to connect via streamable http, falling back to sse")
# fallback to sse
try: try:
async with sse_client(endpoint, headers=headers) as streams: async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
@ -45,40 +62,10 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen
raise raise
@asynccontextmanager
async def streamable_http_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try:
async with streamablehttp_client(endpoint, headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
raise
def get_client_wrapper(endpoint: str, headers: dict[str, str]) -> _AsyncGeneratorContextManager[ClientSession, Any]:
try:
return streamable_http_client_wrapper(endpoint, headers)
except AuthenticationRequiredError as e:
raise e # since this was an authentication error, we want to surface that instead of trying with SSE
except Exception:
return sse_client_wrapper(endpoint, headers)
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = [] tools = []
async with get_client_wrapper(endpoint, headers) as session: async with client_wrapper(endpoint, headers) as session:
logger.info("listing mcp tools...")
tools_result = await session.list_tools() tools_result = await session.list_tools()
for tool in tools_result.tools: for tool in tools_result.tools:
parameters = [] parameters = []
@ -106,7 +93,8 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool( async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult: ) -> ToolInvocationResult:
async with get_client_wrapper(endpoint, headers) as session: async with client_wrapper(endpoint, headers) as session:
logger.info("invoking mcp tool")
result = await session.call_tool(tool_name, kwargs) result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = [] content: list[InterleavedContentItem] = []