diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py index 64bb67b58..f0bcc7190 100644 --- a/tests/proxy_unit_tests/test_proxy_server.py +++ b/tests/proxy_unit_tests/test_proxy_server.py @@ -2125,3 +2125,73 @@ async def test_proxy_server_prisma_setup_invalid_db(): if _old_db_url: os.environ["DATABASE_URL"] = _old_db_url + + +@pytest.mark.asyncio +async def test_async_log_proxy_authentication_errors(): + """ + Test if async_log_proxy_authentication_errors correctly logs authentication errors through custom loggers + """ + import json + from fastapi import Request + from litellm.proxy.utils import ProxyLogging + from litellm.caching import DualCache + from litellm.integrations.custom_logger import CustomLogger + + # Create a mock custom logger to verify it's called + class MockCustomLogger(CustomLogger): + def __init__(self): + self.called = False + self.exception_logged = None + self.request_data_logged = None + self.user_api_key_dict_logged = None + + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + ): + self.called = True + self.exception_logged = original_exception + self.request_data_logged = request_data + print("logged request_data", request_data) + if isinstance(request_data, AsyncMock): + self.request_data_logged = ( + await request_data() + ) # get the actual value from AsyncMock + else: + self.request_data_logged = request_data + self.user_api_key_dict_logged = user_api_key_dict + + # Create test data + test_data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]} + + # Create a mock request + request = Request(scope={"type": "http", "method": "POST"}) + request._json = AsyncMock(return_value=test_data) + + # Create a test exception + test_exception = Exception("Invalid API Key") + + # Initialize ProxyLogging + mock_logger = MockCustomLogger() + litellm.callbacks = [mock_logger] + proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + # Call the method + await proxy_logging_obj.async_log_proxy_authentication_errors( + original_exception=test_exception, + request=request, + parent_otel_span=None, + api_key="test-key", + ) + + # Verify the mock logger was called with correct parameters + assert mock_logger.called == True + assert mock_logger.exception_logged == test_exception + assert mock_logger.request_data_logged == test_data + assert mock_logger.user_api_key_dict_logged is not None + assert ( + mock_logger.user_api_key_dict_logged.token is not None + ) # token should be hashed