test(test_auth_exception_handler.py): add more unit testing

This commit is contained in:
Krrish Dholakia 2025-04-12 16:59:59 -07:00
parent c0c734cd75
commit a6084fa37d
2 changed files with 44 additions and 1 deletions

View file

@ -68,6 +68,7 @@ class UserAPIKeyAuthExceptionHandler:
key_name="failed-to-connect-to-db", key_name="failed-to-connect-to-db",
token="failed-to-connect-to-db", token="failed-to-connect-to-db",
user_id=litellm_proxy_admin_name, user_id=litellm_proxy_admin_name,
request_route=route,
) )
else: else:
# raise the exception to the caller # raise the exception to the caller

View file

@ -2,7 +2,7 @@ import asyncio
import json import json
import os import os
import sys import sys
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, status
@ -110,3 +110,45 @@ async def test_handle_authentication_error_budget_exceeded():
) )
assert exc_info.value.type == ProxyErrorTypes.budget_exceeded assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
@pytest.mark.asyncio
async def test_route_passed_to_post_call_failure_hook():
"""
This route is used by proxy track_cost_callback's async_post_call_failure_hook to check if the route is an LLM route
"""
handler = UserAPIKeyAuthExceptionHandler()
# Mock request and other dependencies
mock_request = MagicMock()
mock_request_data = {}
test_route = "/custom/route"
mock_span = None
mock_api_key = "test-key"
# Mock proxy_logging_obj.post_call_failure_hook
with patch(
"litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
new_callable=AsyncMock,
) as mock_post_call_failure_hook:
# Test with DB connection error
with patch(
"litellm.proxy.proxy_server.general_settings",
{"allow_requests_on_db_unavailable": False},
):
try:
await handler._handle_authentication_error(
PrismaError(),
mock_request,
mock_request_data,
test_route,
mock_span,
mock_api_key,
)
except Exception as e:
pass
asyncio.sleep(1)
# Verify post_call_failure_hook was called with the correct route
mock_post_call_failure_hook.assert_called_once()
call_args = mock_post_call_failure_hook.call_args[1]
assert call_args["user_api_key_dict"].request_route == test_route