diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py index d41ee2d9bf..268e3bb1b2 100644 --- a/litellm/proxy/auth/auth_exception_handler.py +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -68,6 +68,7 @@ class UserAPIKeyAuthExceptionHandler: key_name="failed-to-connect-to-db", token="failed-to-connect-to-db", user_id=litellm_proxy_admin_name, + request_route=route, ) else: # raise the exception to the caller diff --git a/tests/litellm/proxy/auth/test_auth_exception_handler.py b/tests/litellm/proxy/auth/test_auth_exception_handler.py index 224bf24b57..3e780c6ee9 100644 --- a/tests/litellm/proxy/auth/test_auth_exception_handler.py +++ b/tests/litellm/proxy/auth/test_auth_exception_handler.py @@ -2,7 +2,7 @@ import asyncio import json import os import sys -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import HTTPException, Request, status @@ -110,3 +110,45 @@ async def test_handle_authentication_error_budget_exceeded(): ) assert exc_info.value.type == ProxyErrorTypes.budget_exceeded + + +@pytest.mark.asyncio +async def test_route_passed_to_post_call_failure_hook(): + """ + This route is used by proxy track_cost_callback's async_post_call_failure_hook to check if the route is an LLM route + """ + handler = UserAPIKeyAuthExceptionHandler() + + # Mock request and other dependencies + mock_request = MagicMock() + mock_request_data = {} + test_route = "/custom/route" + mock_span = None + mock_api_key = "test-key" + + # Mock proxy_logging_obj.post_call_failure_hook + with patch( + "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook", + new_callable=AsyncMock, + ) as mock_post_call_failure_hook: + # Test with DB connection error + with patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": False}, + ): + try: + await handler._handle_authentication_error( + PrismaError(), + mock_request, + mock_request_data, + test_route, + mock_span, + mock_api_key, + ) + except Exception as e: + pass + asyncio.sleep(1) + # Verify post_call_failure_hook was called with the correct route + mock_post_call_failure_hook.assert_called_once() + call_args = mock_post_call_failure_hook.call_args[1] + assert call_args["user_api_key_dict"].request_route == test_route