test_is_database_connection_error_prisma_errors

This commit is contained in:
Ishaan Jaff 2025-03-26 18:27:39 -07:00
parent 1274d27534
commit c4aab64484

View file

@ -0,0 +1,161 @@
import asyncio
import json
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException, Request, status
from prisma import errors as prisma_errors
from prisma.errors import (
ClientNotConnectedError,
DataError,
ForeignKeyViolationError,
HTTPClientClosedError,
MissingRequiredValueError,
PrismaError,
RawQueryError,
RecordNotFoundError,
TableNotFoundError,
UniqueViolationError,
)
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyErrorTypes, ProxyException
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler
# Test is_database_connection_error method
@pytest.mark.parametrize(
"prisma_error",
[
PrismaError(),
DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
UniqueViolationError(
data={"user_facing_error": {"meta": {"table": "test_table"}}}
),
ForeignKeyViolationError(
data={"user_facing_error": {"meta": {"table": "test_table"}}}
),
MissingRequiredValueError(
data={"user_facing_error": {"meta": {"table": "test_table"}}}
),
RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
TableNotFoundError(
data={"user_facing_error": {"meta": {"table": "test_table"}}}
),
RecordNotFoundError(
data={"user_facing_error": {"meta": {"table": "test_table"}}}
),
HTTPClientClosedError(),
ClientNotConnectedError(),
],
)
def test_is_database_connection_error_prisma_errors(prisma_error):
"""
Test that all Prisma errors are considered database connection errors
"""
handler = UserAPIKeyAuthExceptionHandler()
assert handler.is_database_connection_error(prisma_error) == True
def test_is_database_connection_generic_errors():
"""
Test non-Prisma error cases for database connection checking
"""
handler = UserAPIKeyAuthExceptionHandler()
# Test with ProxyException (DB connection)
db_proxy_exception = ProxyException(
message="DB Connection Error",
type=ProxyErrorTypes.no_db_connection,
param="test-param",
)
assert handler.is_database_connection_error(db_proxy_exception) == True
# Test with non-DB error
regular_exception = Exception("Regular error")
assert handler.is_database_connection_error(regular_exception) == False
# Test should_allow_request_on_db_unavailable method
@patch(
"litellm.proxy.proxy_server.general_settings",
{"allow_requests_on_db_unavailable": True},
)
def test_should_allow_request_on_db_unavailable_true():
handler = UserAPIKeyAuthExceptionHandler()
assert handler.should_allow_request_on_db_unavailable() == True
@patch(
"litellm.proxy.proxy_server.general_settings",
{"allow_requests_on_db_unavailable": False},
)
def test_should_allow_request_on_db_unavailable_false():
handler = UserAPIKeyAuthExceptionHandler()
assert handler.should_allow_request_on_db_unavailable() == False
# Test _handle_authentication_error method
@pytest.mark.asyncio
async def test_handle_authentication_error_db_unavailable():
handler = UserAPIKeyAuthExceptionHandler()
# Mock request and other dependencies
mock_request = MagicMock()
mock_request_data = {}
mock_route = "/test"
mock_span = None
mock_api_key = "test-key"
# Test with DB connection error when requests are allowed
with patch(
"litellm.proxy.proxy_server.general_settings",
{"allow_requests_on_db_unavailable": True},
):
db_error = prisma_errors.PrismaError()
result = await handler._handle_authentication_error(
db_error,
mock_request,
mock_request_data,
mock_route,
mock_span,
mock_api_key,
)
assert result.key_name == "failed-to-connect-to-db"
assert result.token == "failed-to-connect-to-db"
@pytest.mark.asyncio
async def test_handle_authentication_error_budget_exceeded():
handler = UserAPIKeyAuthExceptionHandler()
# Mock request and other dependencies
mock_request = MagicMock()
mock_request_data = {}
mock_route = "/test"
mock_span = None
mock_api_key = "test-key"
# Test with budget exceeded error
with pytest.raises(ProxyException) as exc_info:
from litellm.exceptions import BudgetExceededError
budget_error = BudgetExceededError(
message="Budget exceeded", current_cost=100, max_budget=100
)
await handler._handle_authentication_error(
budget_error,
mock_request,
mock_request_data,
mock_route,
mock_span,
mock_api_key,
)
assert exc_info.value.type == ProxyErrorTypes.budget_exceeded