diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ac235ca0d6..991ab98c85 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -8,6 +8,7 @@ from dataclasses import fields from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +import httpx from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator from typing_extensions import Annotated, TypedDict @@ -1940,6 +1941,9 @@ class ProxyErrorTypes(str, enum.Enum): not_found_error = "not_found_error" +DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout) + + class SSOUserDefinedValues(TypedDict): models: List[str] user_id: str diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 179a6fcf54..3b43ec32e1 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -22,6 +22,7 @@ from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.proxy._types import ( + DB_CONNECTION_ERROR_TYPES, LiteLLM_EndUserTable, LiteLLM_JWTAuth, LiteLLM_OrganizationTable, @@ -753,7 +754,7 @@ async def get_key_object( ) return _response - except httpx.ConnectError as e: + except DB_CONNECTION_ERROR_TYPES as e: return await _handle_failed_db_connection_for_get_key_object(e=e) except Exception: raise Exception( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 1c1a41e0eb..abfaf60f02 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -31,7 +31,11 @@ from litellm.litellm_core_utils.duration_parser import ( duration_in_seconds, get_last_day_of_month, ) -from litellm.proxy._types import ProxyErrorTypes, ProxyException +from litellm.proxy._types import ( + DB_CONNECTION_ERROR_TYPES, + ProxyErrorTypes, + ProxyException, +) try: import backoff @@ -2591,30 +2595,17 @@ async def update_spend( # noqa: PLR0915 {} ) # 
Clear the remaining transactions after processing all batches in the loop. break - except httpx.ReadTimeout: + except DB_CONNECTION_ERROR_TYPES as e: if i >= n_retry_times: # If we've reached the maximum number of retries - raise # Re-raise the last exception + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - import traceback - - error_msg = ( - f"LiteLLM Prisma Client Exception - update user spend: {str(e)}" + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) - print_verbose(error_msg) - error_traceback = error_msg + "\n" + traceback.format_exc() - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - proxy_logging_obj.failure_handler( - original_exception=e, - duration=_duration, - call_type="update_spend", - traceback_str=error_traceback, - ) - ) - raise e ### UPDATE END-USER TABLE ### verbose_proxy_logger.debug( @@ -2652,30 +2643,17 @@ async def update_spend( # noqa: PLR0915 {} ) # Clear the remaining transactions after processing all batches in the loop. 
break - except httpx.ReadTimeout: + except DB_CONNECTION_ERROR_TYPES as e: if i >= n_retry_times: # If we've reached the maximum number of retries - raise # Re-raise the last exception + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - import traceback - - error_msg = ( - f"LiteLLM Prisma Client Exception - update end-user spend: {str(e)}" + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) - print_verbose(error_msg) - error_traceback = error_msg + "\n" + traceback.format_exc() - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - proxy_logging_obj.failure_handler( - original_exception=e, - duration=_duration, - call_type="update_spend", - traceback_str=error_traceback, - ) - ) - raise e ### UPDATE KEY TABLE ### verbose_proxy_logger.debug( @@ -2703,30 +2681,17 @@ async def update_spend( # noqa: PLR0915 {} ) # Clear the remaining transactions after processing all batches in the loop. 
break - except httpx.ReadTimeout: + except DB_CONNECTION_ERROR_TYPES as e: if i >= n_retry_times: # If we've reached the maximum number of retries - raise # Re-raise the last exception + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - import traceback - - error_msg = ( - f"LiteLLM Prisma Client Exception - update key spend: {str(e)}" + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) - print_verbose(error_msg) - error_traceback = error_msg + "\n" + traceback.format_exc() - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - proxy_logging_obj.failure_handler( - original_exception=e, - duration=_duration, - call_type="update_spend", - traceback_str=error_traceback, - ) - ) - raise e ### UPDATE TEAM TABLE ### verbose_proxy_logger.debug( @@ -2759,30 +2724,17 @@ async def update_spend( # noqa: PLR0915 {} ) # Clear the remaining transactions after processing all batches in the loop. 
break - except httpx.ReadTimeout: + except DB_CONNECTION_ERROR_TYPES as e: if i >= n_retry_times: # If we've reached the maximum number of retries - raise # Re-raise the last exception + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - import traceback - - error_msg = ( - f"LiteLLM Prisma Client Exception - update team spend: {str(e)}" + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) - print_verbose(error_msg) - error_traceback = error_msg + "\n" + traceback.format_exc() - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - proxy_logging_obj.failure_handler( - original_exception=e, - duration=_duration, - call_type="update_spend", - traceback_str=error_traceback, - ) - ) - raise e ### UPDATE TEAM Membership TABLE with spend ### if len(prisma_client.team_member_list_transactons.keys()) > 0: @@ -2809,30 +2761,17 @@ async def update_spend( # noqa: PLR0915 {} ) # Clear the remaining transactions after processing all batches in the loop. 
break - except httpx.ReadTimeout: + except DB_CONNECTION_ERROR_TYPES as e: if i >= n_retry_times: # If we've reached the maximum number of retries - raise # Re-raise the last exception + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - import traceback - - error_msg = ( - f"LiteLLM Prisma Client Exception - update team spend: {str(e)}" + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) - print_verbose(error_msg) - error_traceback = error_msg + "\n" + traceback.format_exc() - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - proxy_logging_obj.failure_handler( - original_exception=e, - duration=_duration, - call_type="update_spend", - traceback_str=error_traceback, - ) - ) - raise e ### UPDATE ORG TABLE ### if len(prisma_client.org_list_transactons.keys()) > 0: @@ -2855,30 +2794,17 @@ async def update_spend( # noqa: PLR0915 {} ) # Clear the remaining transactions after processing all batches in the loop. 
break - except httpx.ReadTimeout: + except DB_CONNECTION_ERROR_TYPES as e: if i >= n_retry_times: # If we've reached the maximum number of retries - raise # Re-raise the last exception + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - import traceback - - error_msg = ( - f"LiteLLM Prisma Client Exception - update org spend: {str(e)}" + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) - print_verbose(error_msg) - error_traceback = error_msg + "\n" + traceback.format_exc() - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - proxy_logging_obj.failure_handler( - original_exception=e, - duration=_duration, - call_type="update_spend", - traceback_str=error_traceback, - ) - ) - raise e ### UPDATE SPEND LOGS ### verbose_proxy_logger.debug( @@ -2889,7 +2815,7 @@ async def update_spend( # noqa: PLR0915 MAX_LOGS_PER_INTERVAL = 1000 # Maximum number of logs to flush in a single interval if len(prisma_client.spend_log_transactions) > 0: - for _ in range(n_retry_times + 1): + for i in range(n_retry_times + 1): start_time = time.time() try: base_url = os.getenv("SPEND_LOGS_URL", None) @@ -2913,9 +2839,9 @@ async def update_spend( # noqa: PLR0915 logs_to_process = prisma_client.spend_log_transactions[ :MAX_LOGS_PER_INTERVAL ] - for i in range(0, len(logs_to_process), BATCH_SIZE): + for j in range(0, len(logs_to_process), BATCH_SIZE): # Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE - batch = logs_to_process[i : i + BATCH_SIZE] + batch = logs_to_process[j : j + BATCH_SIZE] # Convert datetime strings to Date objects batch_with_dates = [ @@ -2943,32 +2869,50 @@ async def update_spend( # noqa: PLR0915 f"{len(logs_to_process)} logs processed. 
Remaining in queue: {len(prisma_client.spend_log_transactions)}" ) break - except httpx.ReadTimeout: + except DB_CONNECTION_ERROR_TYPES as e: if i is None: i = 0 - if i >= n_retry_times: # If we've reached the maximum number of retries - raise # Re-raise the last exception + if ( + i >= n_retry_times + ): # If we've reached the maximum number of retries raise the exception + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # type: ignore except Exception as e: - import traceback + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) - error_msg = ( - f"LiteLLM Prisma Client Exception - update spend logs: {str(e)}" - ) - print_verbose(error_msg) - error_traceback = error_msg + "\n" + traceback.format_exc() - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - proxy_logging_obj.failure_handler( - original_exception=e, - duration=_duration, - call_type="update_spend", - traceback_str=error_traceback, - ) - ) - raise e + +def _raise_failed_update_spend_exception( + e: Exception, start_time: float, proxy_logging_obj: ProxyLogging +): + """ + Raise an exception for failed update spend logs + + - Calls proxy_logging_obj.failure_handler to log the error + - Ensures error messages says "Non-Blocking" + """ + import traceback + + error_msg = ( + f"[Non-Blocking]LiteLLM Prisma Client Exception - update spend logs: {str(e)}" + ) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, + duration=_duration, + call_type="update_spend", + traceback_str=error_traceback, + ) + ) + raise e def _is_projected_spend_over_limit( diff --git a/tests/proxy_unit_tests/test_update_spend.py 
b/tests/proxy_unit_tests/test_update_spend.py new file mode 100644 index 0000000000..6efc68a077 --- /dev/null +++ b/tests/proxy_unit_tests/test_update_spend.py @@ -0,0 +1,289 @@ +import asyncio +import os +import sys +from unittest.mock import Mock +from litellm.proxy.utils import _get_redoc_url, _get_docs_url + +import pytest +from fastapi import Request + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from unittest.mock import MagicMock, patch, AsyncMock + + +import httpx +from litellm.proxy.utils import update_spend, DB_CONNECTION_ERROR_TYPES + + +class MockPrismaClient: + def __init__(self): + self.db = MagicMock() + self.spend_log_transactions = [] + self.user_list_transactons = {} + self.end_user_list_transactons = {} + self.key_list_transactons = {} + self.team_list_transactons = {} + self.team_member_list_transactons = {} + self.org_list_transactons = {} + + def jsonify_object(self, obj): + return obj + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error_type", + [ + httpx.ConnectError("Failed to connect"), + httpx.ReadError("Failed to read response"), + httpx.ReadTimeout("Request timed out"), + ], +) +async def test_update_spend_logs_connection_errors(error_type): + """Test retry mechanism for different connection error types""" + # Setup + prisma_client = MockPrismaClient() + proxy_logging_obj = MagicMock() + proxy_logging_obj.failure_handler = AsyncMock() + + # Add test spend logs + prisma_client.spend_log_transactions = [ + {"id": "1", "spend": 10}, + {"id": "2", "spend": 20}, + ] + + # Mock the database to fail with connection error three times then succeed + create_many_mock = AsyncMock() + create_many_mock.side_effect = [ + error_type, # First attempt fails + error_type, # Second attempt fails + error_type, # Third attempt fails + None, # Fourth attempt succeeds + ] + + prisma_client.db.litellm_spendlogs.create_many = create_many_mock + + # Execute + await
update_spend(prisma_client, None, proxy_logging_obj) + + # Verify + assert create_many_mock.call_count == 4 # Should have tried 4 times (initial attempt + 3 retries) + assert ( + len(prisma_client.spend_log_transactions) == 0 + ) # Should have cleared after success + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error_type", + [ + httpx.ConnectError("Failed to connect"), + httpx.ReadError("Failed to read response"), + httpx.ReadTimeout("Request timed out"), + ], +) +async def test_update_spend_logs_max_retries_exceeded(error_type): + """Test that each connection error type properly fails after max retries""" + # Setup + prisma_client = MockPrismaClient() + proxy_logging_obj = MagicMock() + proxy_logging_obj.failure_handler = AsyncMock() + + # Add test spend logs + prisma_client.spend_log_transactions = [ + {"id": "1", "spend": 10}, + {"id": "2", "spend": 20}, + ] + + # Mock the database to always fail + create_many_mock = AsyncMock(side_effect=error_type) + + prisma_client.db.litellm_spendlogs.create_many = create_many_mock + + # Execute and verify it raises after max retries + with pytest.raises(type(error_type)) as exc_info: + await update_spend(prisma_client, None, proxy_logging_obj) + + # Verify error message matches + assert str(exc_info.value) == str(error_type) + # Verify retry attempts (4 attempts total: initial try + 3 retries) + assert create_many_mock.call_count == 4 + + await asyncio.sleep(2) + # Verify failure handler was called + assert proxy_logging_obj.failure_handler.call_count == 1 + + +@pytest.mark.asyncio +async def test_update_spend_logs_non_connection_error(): + """Test handling of non-connection related errors""" + # Setup + prisma_client = MockPrismaClient() + proxy_logging_obj = MagicMock() + proxy_logging_obj.failure_handler = AsyncMock() + + # Add test spend logs + prisma_client.spend_log_transactions = [ + {"id": "1", "spend": 10}, + {"id": "2", "spend": 20}, + ] + + # Mock a different type of error (not connection-related) + unexpected_error = ValueError("Unexpected database
error") + create_many_mock = AsyncMock(side_effect=unexpected_error) + + prisma_client.db.litellm_spendlogs.create_many = create_many_mock + + # Execute and verify it raises immediately without retrying + with pytest.raises(ValueError) as exc_info: + await update_spend(prisma_client, None, proxy_logging_obj) + + # Verify error message + assert str(exc_info.value) == "Unexpected database error" + # Verify only tried once (no retries for non-connection errors) + assert create_many_mock.call_count == 1 + # Verify failure handler was called + assert proxy_logging_obj.failure_handler.called + + +@pytest.mark.asyncio +async def test_update_spend_logs_exponential_backoff(): + """Test that exponential backoff is working correctly""" + # Setup + prisma_client = MockPrismaClient() + proxy_logging_obj = MagicMock() + proxy_logging_obj.failure_handler = AsyncMock() + + # Add test spend logs + prisma_client.spend_log_transactions = [{"id": "1", "spend": 10}] + + # Track sleep times + sleep_times = [] + + # Mock asyncio.sleep to track delay times + async def mock_sleep(seconds): + sleep_times.append(seconds) + + # Mock the database to fail with connection errors + create_many_mock = AsyncMock( + side_effect=[ + httpx.ConnectError("Failed to connect"), # First attempt + httpx.ConnectError("Failed to connect"), # Second attempt + None, # Third attempt succeeds + ] + ) + + prisma_client.db.litellm_spendlogs.create_many = create_many_mock + + # Apply mocks + with patch("asyncio.sleep", mock_sleep): + await update_spend(prisma_client, None, proxy_logging_obj) + + # Verify exponential backoff + assert len(sleep_times) == 2 # Should have slept twice + assert sleep_times[0] == 1 # First retry after 2^0 seconds + assert sleep_times[1] == 2 # Second retry after 2^1 seconds + + +@pytest.mark.asyncio +async def test_update_spend_logs_multiple_batches_success(): + """ + Test successful processing of multiple batches of spend logs + + Code sets batch size to 100. 
This test creates 150 logs, so it should make 2 batches. + """ + # Setup + prisma_client = MockPrismaClient() + proxy_logging_obj = MagicMock() + proxy_logging_obj.failure_handler = AsyncMock() + + # Create 150 test spend logs (1.5x BATCH_SIZE) + prisma_client.spend_log_transactions = [ + {"id": str(i), "spend": 10} for i in range(150) + ] + + create_many_mock = AsyncMock(return_value=None) + prisma_client.db.litellm_spendlogs.create_many = create_many_mock + + # Execute + await update_spend(prisma_client, None, proxy_logging_obj) + + # Verify + assert create_many_mock.call_count == 2 # Should have made 2 batch calls + + # Get the actual data from each batch call + first_batch = create_many_mock.call_args_list[0][1]["data"] + second_batch = create_many_mock.call_args_list[1][1]["data"] + + # Verify batch sizes + assert len(first_batch) == 100 + assert len(second_batch) == 50 + + # Verify exact IDs in each batch + expected_first_batch_ids = {str(i) for i in range(100)} + expected_second_batch_ids = {str(i) for i in range(100, 150)} + + actual_first_batch_ids = {item["id"] for item in first_batch} + actual_second_batch_ids = {item["id"] for item in second_batch} + + assert actual_first_batch_ids == expected_first_batch_ids + assert actual_second_batch_ids == expected_second_batch_ids + + # Verify all logs were processed + assert len(prisma_client.spend_log_transactions) == 0 + + +@pytest.mark.asyncio +async def test_update_spend_logs_multiple_batches_with_failure(): + """ + Test processing of multiple batches where one batch fails. + Creates 400 logs (4 batches) with one batch failing but eventually succeeding after retry. 
+ """ + # Setup + prisma_client = MockPrismaClient() + proxy_logging_obj = MagicMock() + proxy_logging_obj.failure_handler = AsyncMock() + + # Create 400 test spend logs (4x BATCH_SIZE) + prisma_client.spend_log_transactions = [ + {"id": str(i), "spend": 10} for i in range(400) + ] + + # Mock to fail on second batch first attempt, then succeed + call_count = 0 + + async def create_many_side_effect(**kwargs): + nonlocal call_count + call_count += 1 + # Fail on the second batch's first attempt + if call_count == 2: + raise httpx.ConnectError("Failed to connect") + return None + + create_many_mock = AsyncMock(side_effect=create_many_side_effect) + prisma_client.db.litellm_spendlogs.create_many = create_many_mock + + # Execute + await update_spend(prisma_client, None, proxy_logging_obj) + + # Verify + assert create_many_mock.call_count == 6 # 4 batches + 2 retries for failed batch + + # Verify all batches were processed + all_processed_logs = [] + for call in create_many_mock.call_args_list: + all_processed_logs.extend(call[1]["data"]) + + # Verify all IDs were processed + processed_ids = {item["id"] for item in all_processed_logs} + + # these should have ids 0-399 + print("all processed ids", sorted(processed_ids, key=int)) + expected_ids = {str(i) for i in range(400)} + assert processed_ids == expected_ids + + # Verify all logs were cleared from transactions + assert len(prisma_client.spend_log_transactions) == 0