diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 011ed04de..3281fe4b4 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -170,10 +170,10 @@ from litellm.proxy.guardrails.init_guardrails import ( ) from litellm.proxy.health_check import perform_health_check from litellm.proxy.health_endpoints._health_endpoints import router as health_router -from litellm.proxy.hooks.failure_handler import _PROXY_failure_handler from litellm.proxy.hooks.prompt_injection_detection import ( _OPTIONAL_PromptInjectionDetection, ) +from litellm.proxy.hooks.proxy_failure_handler import _PROXY_failure_handler from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request from litellm.proxy.management_endpoints.customer_endpoints import ( router as customer_router, diff --git a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py index 5573ad096..095b15368 100644 --- a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py +++ b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py @@ -72,3 +72,40 @@ async def test_disable_spend_logs(): ) # Verify no spend logs were added assert len(mock_prisma_client.spend_log_transactions) == 0 + + +@pytest.mark.asyncio +async def test_enable_error_logs(): + """ + Test that the error logs are written to the database when disable_error_logs is False + """ + # Mock the necessary components + mock_prisma_client = AsyncMock() + mock_general_settings = {"disable_error_logs": False} + + with patch( + "litellm.proxy.proxy_server.general_settings", mock_general_settings + ), patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client): + + # Create a test exception + test_exception = Exception("Test error") + test_kwargs = { + "model": "gpt-4", + "exception": test_exception, + "optional_params": {}, + "litellm_params": {"metadata": {}}, + } + + # Call the failure handler + from litellm.proxy.proxy_server import _PROXY_failure_handler + + await _PROXY_failure_handler( + kwargs=test_kwargs, + completion_response=None, + start_time="2024-01-01", + end_time="2024-01-01", + ) + + # Verify prisma client was called to create error logs + if hasattr(mock_prisma_client, "db"): + assert mock_prisma_client.db.litellm_errorlogs.create.called