diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index b247a610d..288bf46e1 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -141,6 +141,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) elif isinstance(exc, PermissionError | AccessDeniedError): return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") + elif isinstance(exc, ConnectionError | httpx.ConnectError): + return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc)) elif isinstance(exc, asyncio.TimeoutError | TimeoutError): return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index 02f7aaf8a..fc8e2f377 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat raise AuthenticationRequiredError(exc) from exc if i == len(connection_strategies) - 1: raise + except* httpx.ConnectError as eg: + # Connection refused, server down, network unreachable + if i == len(connection_strategies) - 1: + error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused" + logger.error(f"MCP connection error: {error_msg}") + raise ConnectionError(error_msg) from eg + else: + logger.warning( + f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + except* httpx.TimeoutException as eg: + # Request timeout, server too slow + if i == len(connection_strategies) - 1: + error_msg = f"MCP server at {endpoint} timed out" + logger.error(f"MCP timeout error: {error_msg}") + raise TimeoutError(error_msg) from eg + else: + logger.warning( + f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + except* httpx.RequestError as eg: + # DNS resolution failures, network errors, invalid URLs + if i == len(connection_strategies) - 1: + # Get the first exception's message for the error string + exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error" + error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}" + logger.error(f"MCP network error: {error_msg}") + raise ConnectionError(error_msg) from eg + else: + logger.warning( + f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) except* McpError: if i < len(connection_strategies) - 1: logger.warning( diff --git a/tests/unit/server/test_server.py b/tests/unit/server/test_server.py index 803111fc7..f21bbdd67 100644 --- a/tests/unit/server/test_server.py +++ b/tests/unit/server/test_server.py @@ -113,6 +113,15 @@ class TestTranslateException: assert result.status_code == 504 assert result.detail == "Operation timed out: " + def test_translate_connection_error(self): + """Test that ConnectionError is translated to 502 HTTP status.""" + exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 502 + assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused" + def test_translate_not_implemented_error(self): """Test that NotImplementedError is translated to 501 HTTP status.""" exc = NotImplementedError("Not implemented")