mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
(bug fix) SpendLogs update DB catch all possible DB errors for retrying (#7082)
* catch DB_CONNECTION_ERROR_TYPES * fix DB retry mechanism for SpendLog updates * use DB_CONNECTION_ERROR_TYPES in auth checks * fix exp back off for writing SpendLogs * use _raise_failed_update_spend_exception to ensure errors print as NON blocking * test_update_spend_logs_multiple_batches_with_failure
This commit is contained in:
parent
6ec920d0b4
commit
b78eb6654d
4 changed files with 377 additions and 139 deletions
|
@ -8,6 +8,7 @@ from dataclasses import fields
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
|
from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
|
||||||
from typing_extensions import Annotated, TypedDict
|
from typing_extensions import Annotated, TypedDict
|
||||||
|
|
||||||
|
@ -1940,6 +1941,9 @@ class ProxyErrorTypes(str, enum.Enum):
|
||||||
not_found_error = "not_found_error"
|
not_found_error = "not_found_error"
|
||||||
|
|
||||||
|
|
||||||
|
DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout)
|
||||||
|
|
||||||
|
|
||||||
class SSOUserDefinedValues(TypedDict):
|
class SSOUserDefinedValues(TypedDict):
|
||||||
models: List[str]
|
models: List[str]
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
|
@ -22,6 +22,7 @@ from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
LiteLLM_OrganizationTable,
|
LiteLLM_OrganizationTable,
|
||||||
|
@ -753,7 +754,7 @@ async def get_key_object(
|
||||||
)
|
)
|
||||||
|
|
||||||
return _response
|
return _response
|
||||||
except httpx.ConnectError as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
return await _handle_failed_db_connection_for_get_key_object(e=e)
|
return await _handle_failed_db_connection_for_get_key_object(e=e)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -31,7 +31,11 @@ from litellm.litellm_core_utils.duration_parser import (
|
||||||
duration_in_seconds,
|
duration_in_seconds,
|
||||||
get_last_day_of_month,
|
get_last_day_of_month,
|
||||||
)
|
)
|
||||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException
|
from litellm.proxy._types import (
|
||||||
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
|
ProxyErrorTypes,
|
||||||
|
ProxyException,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import backoff
|
import backoff
|
||||||
|
@ -2591,30 +2595,17 @@ async def update_spend( # noqa: PLR0915
|
||||||
{}
|
{}
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
break
|
break
|
||||||
except httpx.ReadTimeout:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
raise # Re-raise the last exception
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
# Optionally, sleep for a bit before retrying
|
# Optionally, sleep for a bit before retrying
|
||||||
await asyncio.sleep(2**i) # Exponential backoff
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
error_msg = (
|
|
||||||
f"LiteLLM Prisma Client Exception - update user spend: {str(e)}"
|
|
||||||
)
|
)
|
||||||
print_verbose(error_msg)
|
|
||||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
|
||||||
end_time = time.time()
|
|
||||||
_duration = end_time - start_time
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.failure_handler(
|
|
||||||
original_exception=e,
|
|
||||||
duration=_duration,
|
|
||||||
call_type="update_spend",
|
|
||||||
traceback_str=error_traceback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
### UPDATE END-USER TABLE ###
|
### UPDATE END-USER TABLE ###
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -2652,30 +2643,17 @@ async def update_spend( # noqa: PLR0915
|
||||||
{}
|
{}
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
break
|
break
|
||||||
except httpx.ReadTimeout:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
raise # Re-raise the last exception
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
# Optionally, sleep for a bit before retrying
|
# Optionally, sleep for a bit before retrying
|
||||||
await asyncio.sleep(2**i) # Exponential backoff
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
error_msg = (
|
|
||||||
f"LiteLLM Prisma Client Exception - update end-user spend: {str(e)}"
|
|
||||||
)
|
)
|
||||||
print_verbose(error_msg)
|
|
||||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
|
||||||
end_time = time.time()
|
|
||||||
_duration = end_time - start_time
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.failure_handler(
|
|
||||||
original_exception=e,
|
|
||||||
duration=_duration,
|
|
||||||
call_type="update_spend",
|
|
||||||
traceback_str=error_traceback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
### UPDATE KEY TABLE ###
|
### UPDATE KEY TABLE ###
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -2703,30 +2681,17 @@ async def update_spend( # noqa: PLR0915
|
||||||
{}
|
{}
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
break
|
break
|
||||||
except httpx.ReadTimeout:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
raise # Re-raise the last exception
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
# Optionally, sleep for a bit before retrying
|
# Optionally, sleep for a bit before retrying
|
||||||
await asyncio.sleep(2**i) # Exponential backoff
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
error_msg = (
|
|
||||||
f"LiteLLM Prisma Client Exception - update key spend: {str(e)}"
|
|
||||||
)
|
)
|
||||||
print_verbose(error_msg)
|
|
||||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
|
||||||
end_time = time.time()
|
|
||||||
_duration = end_time - start_time
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.failure_handler(
|
|
||||||
original_exception=e,
|
|
||||||
duration=_duration,
|
|
||||||
call_type="update_spend",
|
|
||||||
traceback_str=error_traceback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
### UPDATE TEAM TABLE ###
|
### UPDATE TEAM TABLE ###
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -2759,30 +2724,17 @@ async def update_spend( # noqa: PLR0915
|
||||||
{}
|
{}
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
break
|
break
|
||||||
except httpx.ReadTimeout:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
raise # Re-raise the last exception
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
# Optionally, sleep for a bit before retrying
|
# Optionally, sleep for a bit before retrying
|
||||||
await asyncio.sleep(2**i) # Exponential backoff
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
error_msg = (
|
|
||||||
f"LiteLLM Prisma Client Exception - update team spend: {str(e)}"
|
|
||||||
)
|
)
|
||||||
print_verbose(error_msg)
|
|
||||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
|
||||||
end_time = time.time()
|
|
||||||
_duration = end_time - start_time
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.failure_handler(
|
|
||||||
original_exception=e,
|
|
||||||
duration=_duration,
|
|
||||||
call_type="update_spend",
|
|
||||||
traceback_str=error_traceback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
### UPDATE TEAM Membership TABLE with spend ###
|
### UPDATE TEAM Membership TABLE with spend ###
|
||||||
if len(prisma_client.team_member_list_transactons.keys()) > 0:
|
if len(prisma_client.team_member_list_transactons.keys()) > 0:
|
||||||
|
@ -2809,30 +2761,17 @@ async def update_spend( # noqa: PLR0915
|
||||||
{}
|
{}
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
break
|
break
|
||||||
except httpx.ReadTimeout:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
raise # Re-raise the last exception
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
# Optionally, sleep for a bit before retrying
|
# Optionally, sleep for a bit before retrying
|
||||||
await asyncio.sleep(2**i) # Exponential backoff
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
error_msg = (
|
|
||||||
f"LiteLLM Prisma Client Exception - update team spend: {str(e)}"
|
|
||||||
)
|
)
|
||||||
print_verbose(error_msg)
|
|
||||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
|
||||||
end_time = time.time()
|
|
||||||
_duration = end_time - start_time
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.failure_handler(
|
|
||||||
original_exception=e,
|
|
||||||
duration=_duration,
|
|
||||||
call_type="update_spend",
|
|
||||||
traceback_str=error_traceback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
### UPDATE ORG TABLE ###
|
### UPDATE ORG TABLE ###
|
||||||
if len(prisma_client.org_list_transactons.keys()) > 0:
|
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||||
|
@ -2855,30 +2794,17 @@ async def update_spend( # noqa: PLR0915
|
||||||
{}
|
{}
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
break
|
break
|
||||||
except httpx.ReadTimeout:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
raise # Re-raise the last exception
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
# Optionally, sleep for a bit before retrying
|
# Optionally, sleep for a bit before retrying
|
||||||
await asyncio.sleep(2**i) # Exponential backoff
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
error_msg = (
|
|
||||||
f"LiteLLM Prisma Client Exception - update org spend: {str(e)}"
|
|
||||||
)
|
)
|
||||||
print_verbose(error_msg)
|
|
||||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
|
||||||
end_time = time.time()
|
|
||||||
_duration = end_time - start_time
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.failure_handler(
|
|
||||||
original_exception=e,
|
|
||||||
duration=_duration,
|
|
||||||
call_type="update_spend",
|
|
||||||
traceback_str=error_traceback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
### UPDATE SPEND LOGS ###
|
### UPDATE SPEND LOGS ###
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -2889,7 +2815,7 @@ async def update_spend( # noqa: PLR0915
|
||||||
MAX_LOGS_PER_INTERVAL = 1000 # Maximum number of logs to flush in a single interval
|
MAX_LOGS_PER_INTERVAL = 1000 # Maximum number of logs to flush in a single interval
|
||||||
|
|
||||||
if len(prisma_client.spend_log_transactions) > 0:
|
if len(prisma_client.spend_log_transactions) > 0:
|
||||||
for _ in range(n_retry_times + 1):
|
for i in range(n_retry_times + 1):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
base_url = os.getenv("SPEND_LOGS_URL", None)
|
base_url = os.getenv("SPEND_LOGS_URL", None)
|
||||||
|
@ -2913,9 +2839,9 @@ async def update_spend( # noqa: PLR0915
|
||||||
logs_to_process = prisma_client.spend_log_transactions[
|
logs_to_process = prisma_client.spend_log_transactions[
|
||||||
:MAX_LOGS_PER_INTERVAL
|
:MAX_LOGS_PER_INTERVAL
|
||||||
]
|
]
|
||||||
for i in range(0, len(logs_to_process), BATCH_SIZE):
|
for j in range(0, len(logs_to_process), BATCH_SIZE):
|
||||||
# Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
|
# Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
|
||||||
batch = logs_to_process[i : i + BATCH_SIZE]
|
batch = logs_to_process[j : j + BATCH_SIZE]
|
||||||
|
|
||||||
# Convert datetime strings to Date objects
|
# Convert datetime strings to Date objects
|
||||||
batch_with_dates = [
|
batch_with_dates = [
|
||||||
|
@ -2943,32 +2869,50 @@ async def update_spend( # noqa: PLR0915
|
||||||
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
|
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except httpx.ReadTimeout:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if i is None:
|
if i is None:
|
||||||
i = 0
|
i = 0
|
||||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
if (
|
||||||
raise # Re-raise the last exception
|
i >= n_retry_times
|
||||||
|
): # If we've reached the maximum number of retries raise the exception
|
||||||
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
|
|
||||||
# Optionally, sleep for a bit before retrying
|
# Optionally, sleep for a bit before retrying
|
||||||
await asyncio.sleep(2**i) # type: ignore
|
await asyncio.sleep(2**i) # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
_raise_failed_update_spend_exception(
|
||||||
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
|
|
||||||
error_msg = (
|
|
||||||
f"LiteLLM Prisma Client Exception - update spend logs: {str(e)}"
|
def _raise_failed_update_spend_exception(
|
||||||
)
|
e: Exception, start_time: float, proxy_logging_obj: ProxyLogging
|
||||||
print_verbose(error_msg)
|
):
|
||||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
"""
|
||||||
end_time = time.time()
|
Raise an exception for failed update spend logs
|
||||||
_duration = end_time - start_time
|
|
||||||
asyncio.create_task(
|
- Calls proxy_logging_obj.failure_handler to log the error
|
||||||
proxy_logging_obj.failure_handler(
|
- Ensures error messages says "Non-Blocking"
|
||||||
original_exception=e,
|
"""
|
||||||
duration=_duration,
|
import traceback
|
||||||
call_type="update_spend",
|
|
||||||
traceback_str=error_traceback,
|
error_msg = (
|
||||||
)
|
f"[Non-Blocking]LiteLLM Prisma Client Exception - update spend logs: {str(e)}"
|
||||||
)
|
)
|
||||||
raise e
|
error_traceback = error_msg + "\n" + traceback.format_exc()
|
||||||
|
end_time = time.time()
|
||||||
|
_duration = end_time - start_time
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.failure_handler(
|
||||||
|
original_exception=e,
|
||||||
|
duration=_duration,
|
||||||
|
call_type="update_spend",
|
||||||
|
traceback_str=error_traceback,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def _is_projected_spend_over_limit(
|
def _is_projected_spend_over_limit(
|
||||||
|
|
289
tests/proxy_unit_tests/test_update_spend.py
Normal file
289
tests/proxy_unit_tests/test_update_spend.py
Normal file
|
@ -0,0 +1,289 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import Mock
|
||||||
|
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from litellm.proxy.utils import update_spend, DB_CONNECTION_ERROR_TYPES
|
||||||
|
|
||||||
|
|
||||||
|
class MockPrismaClient:
    """In-memory stand-in for the Prisma client object handed to ``update_spend``.

    Exposes a ``MagicMock`` as ``db`` plus the pending-transaction containers
    the tests populate: a list of spend logs and one empty dict per entity
    queue.
    """

    def __init__(self):
        self.db = MagicMock()
        self.spend_log_transactions = []
        # Attribute names (including the "transactons" spelling) are kept
        # exactly as written; each queue gets its own fresh dict.
        for queue_name in (
            "user_list_transactons",
            "end_user_list_transactons",
            "key_list_transactons",
            "team_list_transactons",
            "team_member_list_transactons",
            "org_list_transactons",
        ):
            setattr(self, queue_name, {})

    def jsonify_object(self, obj):
        """Return ``obj`` unchanged (identity pass-through)."""
        return obj
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_type",
    [
        httpx.ConnectError("Failed to connect"),
        httpx.ReadError("Failed to read response"),
        httpx.ReadTimeout("Request timed out"),
    ],
)
async def test_update_spend_logs_connection_errors(error_type):
    """Test retry mechanism for different connection error types.

    ``create_many`` fails three times with ``error_type`` and succeeds on the
    fourth call; ``update_spend`` must return normally and drain the queue.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = MagicMock()
    proxy_logging_obj.failure_handler = AsyncMock()

    # Add test spend logs
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]

    # Mock the database to fail with the connection error three times, then succeed
    create_many_mock = AsyncMock()
    create_many_mock.side_effect = [
        error_type,  # First attempt fails
        error_type,  # Second attempt fails
        error_type,  # Third attempt fails
        None,  # Fourth attempt succeeds
    ]

    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute. The retry path awaits asyncio.sleep(2**i) between attempts;
    # patch it out so the test does not wait through the real backoff.
    with patch("asyncio.sleep", new_callable=AsyncMock):
        await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify
    assert create_many_mock.call_count == 4  # 3 failed attempts + 1 successful attempt
    assert (
        len(prisma_client.spend_log_transactions) == 0
    )  # Should have cleared after success
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_type",
    [
        httpx.ConnectError("Failed to connect"),
        httpx.ReadError("Failed to read response"),
        httpx.ReadTimeout("Request timed out"),
    ],
)
async def test_update_spend_logs_max_retries_exceeded(error_type):
    """Test that each connection error type properly fails after max retries.

    ``create_many`` always raises ``error_type``; ``update_spend`` must give up
    after its retry budget, re-raise the same error, and report the failure
    through ``proxy_logging_obj.failure_handler``.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = MagicMock()
    proxy_logging_obj.failure_handler = AsyncMock()

    # Add test spend logs
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]

    # Mock the database to always fail
    create_many_mock = AsyncMock(side_effect=error_type)

    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute and verify it raises after max retries. The backoff sleep between
    # attempts is patched out so the test does not wait for real.
    with patch("asyncio.sleep", new_callable=AsyncMock):
        with pytest.raises(type(error_type)) as exc_info:
            await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify error message matches
    assert str(exc_info.value) == str(error_type)
    # Verify attempts: 4 calls in total = initial try + 3 retries
    # (presumably n_retry_times == 3; that value is set outside this file)
    assert create_many_mock.call_count == 4

    # Real (unpatched) sleep: yield to the event loop so the failure-handler
    # task scheduled with asyncio.create_task gets a chance to run.
    await asyncio.sleep(2)
    # Verify failure handler was called
    assert proxy_logging_obj.failure_handler.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_spend_logs_non_connection_error():
    """A non-connection error from the DB is raised at once, without any retry."""
    # Arrange: mock client, mock logger with an awaitable failure handler
    logging_stub = MagicMock()
    logging_stub.failure_handler = AsyncMock()
    client = MockPrismaClient()

    # Two queued spend logs
    client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]

    # create_many always raises an error that is not one of the connection types
    insert_mock = AsyncMock(side_effect=ValueError("Unexpected database error"))
    client.db.litellm_spendlogs.create_many = insert_mock

    # Act: the error must propagate out of update_spend
    with pytest.raises(ValueError) as raised:
        await update_spend(client, None, logging_stub)

    # Assert: same message, exactly one attempt, failure handler invoked
    assert str(raised.value) == "Unexpected database error"
    assert insert_mock.call_count == 1
    assert logging_stub.failure_handler.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_spend_logs_exponential_backoff():
    """Retries wait 2**attempt seconds: 1s after the first failure, 2s after the second."""
    # Arrange
    logging_stub = MagicMock()
    logging_stub.failure_handler = AsyncMock()
    client = MockPrismaClient()
    client.spend_log_transactions = [{"id": "1", "spend": 10}]

    # Every delay requested through asyncio.sleep is recorded here instead of waited on
    recorded_delays = []

    async def fake_sleep(seconds):
        recorded_delays.append(seconds)

    # Two connection failures followed by a successful write
    outcomes = [
        httpx.ConnectError("Failed to connect"),  # attempt 1
        httpx.ConnectError("Failed to connect"),  # attempt 2
        None,  # attempt 3 succeeds
    ]
    insert_mock = AsyncMock(side_effect=outcomes)
    client.db.litellm_spendlogs.create_many = insert_mock

    # Act with asyncio.sleep swapped for the recorder
    with patch("asyncio.sleep", fake_sleep):
        await update_spend(client, None, logging_stub)

    # Assert: exactly two sleeps, of 2**0 and 2**1 seconds
    assert len(recorded_delays) == 2
    first_delay, second_delay = recorded_delays
    assert first_delay == 1
    assert second_delay == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_spend_logs_multiple_batches_success():
    """
    150 queued spend logs are written in two create_many calls and the queue is emptied.

    The code under test uses a batch size of 100, so the expected split is
    ids 0-99 followed by ids 100-149.
    """
    # Arrange
    logging_stub = MagicMock()
    logging_stub.failure_handler = AsyncMock()
    client = MockPrismaClient()

    # 150 logs = 1.5x the batch size
    client.spend_log_transactions = [
        {"id": str(n), "spend": 10} for n in range(150)
    ]

    insert_mock = AsyncMock(return_value=None)
    client.db.litellm_spendlogs.create_many = insert_mock

    # Act
    await update_spend(client, None, logging_stub)

    # Assert: one call per batch
    assert insert_mock.call_count == 2

    # Pull the `data` keyword argument out of each recorded call
    first_call, second_call = insert_mock.call_args_list
    first_batch = first_call[1]["data"]
    second_batch = second_call[1]["data"]

    # Batch sizes
    assert len(first_batch) == 100
    assert len(second_batch) == 50

    # Exact ids carried by each batch
    assert {row["id"] for row in first_batch} == {str(n) for n in range(100)}
    assert {row["id"] for row in second_batch} == {str(n) for n in range(100, 150)}

    # Nothing left in the queue
    assert len(client.spend_log_transactions) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_spend_logs_multiple_batches_with_failure():
    """
    Test processing of multiple batches where one batch fails.

    Creates 400 logs (4 batches). The second ``create_many`` call raises a
    connection error; every later call succeeds. ``update_spend`` must still
    write all 400 logs and clear the queue.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = MagicMock()
    proxy_logging_obj.failure_handler = AsyncMock()

    # Create 400 test spend logs (4x BATCH_SIZE)
    prisma_client.spend_log_transactions = [
        {"id": str(i), "spend": 10} for i in range(400)
    ]

    # Mock to fail on second batch first attempt, then succeed
    call_count = 0

    async def create_many_side_effect(**kwargs):
        nonlocal call_count
        call_count += 1
        # Fail on the second batch's first attempt
        if call_count == 2:
            raise httpx.ConnectError("Failed to connect")
        return None

    create_many_mock = AsyncMock(side_effect=create_many_side_effect)
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute
    await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify: 6 calls = 2 in the first pass (batch 1 succeeds, batch 2 fails)
    # + 4 in the retry pass. A count of 6 rather than 5 is consistent with the
    # retry re-sending every batch from the start, not only the failed one.
    assert create_many_mock.call_count == 6

    # Verify all batches were processed
    all_processed_logs = []
    for call in create_many_mock.call_args_list:
        all_processed_logs.extend(call[1]["data"])

    # Verify all IDs were processed (a set, so re-sent rows collapse)
    processed_ids = {item["id"] for item in all_processed_logs}

    # these should have ids 0-399
    print("all processed ids", sorted(processed_ids, key=int))
    expected_ids = {str(i) for i in range(400)}
    assert processed_ids == expected_ids

    # Verify all logs were cleared from transactions
    assert len(prisma_client.spend_log_transactions) == 0
|
Loading…
Add table
Add a link
Reference in a new issue