mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
UI Tests / ui-tests (22) (push) Successful in 1m20s
Pre-commit / pre-commit (push) Successful in 2m37s
What does this PR do?
Fixes error handling when MCP server connections fail. Instead of returning a generic 500 error, the server now returns a descriptive error message with a proper HTTP status code.
Closes #3107

Test Plan
Before fix:
  curl -X GET "http://localhost:8321/v1/tool-runtime/list-tools?tool_group_id=bad-mcp-server"
  Returns: {"detail": "Internal server error: An unexpected error occurred."} (500)
After fix:
  curl -X GET "http://localhost:8321/v1/tool-runtime/list-tools?tool_group_id=bad-mcp-server"
  Returns: {"error": {"detail": "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"}} (502)
Tests:
- Added unit test for ConnectionError → 502 translation
- Manually tested with unreachable MCP servers (connection refused)
196 lines
7.9 KiB
Python
196 lines
7.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from unittest.mock import Mock
|
|
|
|
from fastapi import HTTPException
|
|
from openai import BadRequestError
|
|
from pydantic import ValidationError
|
|
|
|
from llama_stack.core.access_control.access_control import AccessDeniedError
|
|
from llama_stack.core.datatypes import AuthenticationRequiredError
|
|
from llama_stack.core.server.server import translate_exception
|
|
|
|
|
|
class TestTranslateException:
    """Verify the HTTPException that translate_exception builds for each exception type."""

    @staticmethod
    def _translated(error, expected_status):
        """Translate *error* and check the common expectations.

        Asserts that the translation is an HTTPException carrying
        *expected_status*, then returns it so the caller can inspect
        ``detail``.
        """
        http_error = translate_exception(error)
        assert isinstance(http_error, HTTPException)
        assert http_error.status_code == expected_status
        return http_error

    def test_translate_access_denied_error(self):
        """A bare AccessDeniedError becomes a 403 with the default message."""
        http_error = self._translated(AccessDeniedError(), 403)

        assert http_error.detail == "Permission denied: Insufficient permissions"

    def test_translate_access_denied_error_with_context(self):
        """An AccessDeniedError built with action, resource and user reports them in the detail."""
        from llama_stack.core.datatypes import User

        # The user whose principal and attributes should show up in the message
        requesting_user = User("test-user", {"roles": ["user"], "teams": ["dev"]})

        # Minimal stand-in exposing the attributes of the ProtectedResource protocol
        class MockResource:
            def __init__(self, resource_type: str, identifier: str, owner=None):
                self.type = resource_type
                self.identifier = identifier
                self.owner = owner

        target = MockResource("vector_db", "test-db")

        http_error = self._translated(AccessDeniedError("create", target, requesting_user), 403)

        detail = http_error.detail
        assert "test-user" in detail
        assert "vector_db::test-db" in detail
        assert "create" in detail
        assert "roles=['user']" in detail
        assert "teams=['dev']" in detail

    def test_translate_permission_error(self):
        """A builtin PermissionError becomes a 403 that echoes its message."""
        http_error = self._translated(PermissionError("Permission denied"), 403)

        assert http_error.detail == "Permission denied: Permission denied"

    def test_translate_value_error(self):
        """A ValueError becomes a 400 whose detail is prefixed with 'Invalid value'."""
        http_error = self._translated(ValueError("Invalid input"), 400)

        assert http_error.detail == "Invalid value: Invalid input"

    def test_translate_bad_request_error(self):
        """An openai BadRequestError becomes a 400 carrying the original message."""
        # BadRequestError needs a response object; a Mock with the two
        # attributes set here is enough to construct it.
        fake_response = Mock(status_code=400, headers={})

        error = BadRequestError("Bad request", response=fake_response, body="Bad request")
        http_error = self._translated(error, 400)

        assert http_error.detail == "Bad request"

    def test_translate_authentication_required_error(self):
        """An AuthenticationRequiredError becomes a 401 that echoes its message."""
        http_error = self._translated(AuthenticationRequiredError("Authentication required"), 401)

        assert http_error.detail == "Authentication required: Authentication required"

    def test_translate_timeout_error(self):
        """A TimeoutError with a message becomes a 504 that echoes the message."""
        http_error = self._translated(TimeoutError("Operation timed out"), 504)

        assert http_error.detail == "Operation timed out: Operation timed out"

    def test_translate_asyncio_timeout_error(self):
        """A message-less TimeoutError becomes a 504 with an empty message suffix.

        This stands in for asyncio.TimeoutError, which is an alias of the
        builtin TimeoutError on Python 3.11 and later.
        """
        http_error = self._translated(TimeoutError(), 504)

        assert http_error.detail == "Operation timed out: "

    def test_translate_connection_error(self):
        """A ConnectionError becomes a 502 whose detail is the message unchanged."""
        error = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
        http_error = self._translated(error, 502)

        assert http_error.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"

    def test_translate_not_implemented_error(self):
        """A NotImplementedError becomes a 501 that echoes its message."""
        http_error = self._translated(NotImplementedError("Not implemented"), 501)

        assert http_error.detail == "Not implemented: Not implemented"

    def test_translate_validation_error(self):
        """A pydantic ValidationError becomes a 400 with a structured 'errors' list."""
        # NOTE(review): the message asserted below is capitalised while the
        # "msg" supplied here is not; presumably pydantic derives the text
        # from the "missing" error type. Confirm against pydantic.
        line_errors = [
            {
                "loc": ("field", "nested"),
                "msg": "field required",
                "type": "missing",
            }
        ]
        error = ValidationError.from_exception_data("TestModel", line_errors)

        http_error = self._translated(error, 400)

        assert "errors" in http_error.detail
        assert len(http_error.detail["errors"]) == 1
        first_error = http_error.detail["errors"][0]
        assert first_error["loc"] == ["field", "nested"]
        assert first_error["msg"] == "Field required"
        assert first_error["type"] == "missing"

    def test_translate_generic_exception(self):
        """A plain Exception becomes a 500 with a fixed detail that omits its message."""
        http_error = self._translated(Exception("Unexpected error"), 500)

        assert http_error.detail == "Internal server error: An unexpected error occurred."

    def test_translate_runtime_error(self):
        """A RuntimeError becomes a 500 with the same fixed detail as a plain Exception."""
        http_error = self._translated(RuntimeError("Runtime error"), 500)

        assert http_error.detail == "Internal server error: An unexpected error occurred."

    def test_multiple_access_denied_scenarios(self):
        """Several permission-related exceptions all end up as 403 responses."""
        # AccessDeniedError without arguments yields the default message
        denied = self._translated(AccessDeniedError(), 403)
        assert denied.detail == "Permission denied: Insufficient permissions"

        # Builtin PermissionError echoes whatever message it was raised with
        for message in ("No permission", "Access denied"):
            forbidden = self._translated(PermissionError(message), 403)
            assert forbidden.detail == f"Permission denied: {message}"
|