(Bug Fix + Better Observability) - BudgetResetJob (#8562)

* use class ResetBudgetJob

* refactor reset budget job

* update reset_budget job

* refactor reset budget job

* fix LiteLLM_UserTable

* refactor reset budget job

* add telemetry for reset budget job

* dd - log service success/failure on DD

* add detailed reset budget info on DD

* initialize_scheduled_background_jobs

* refactor reset budget job

* trigger service failure hook when fails to reset a budget for team, key, user

* fix resetBudgetJob

* unit testing for ResetBudgetJob

* test_duration_in_seconds_basic

* testing for triggering service logging

* fix logs on test teams fail

* remove unused imports

* fix import duration in s

* duration_in_seconds
Ishaan Jaff authored on 2025-02-15 16:13:08 -08:00, committed by GitHub
parent a8717ea124
commit c8d31a209b
11 changed files with 1107 additions and 87 deletions


@@ -35,12 +35,18 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.datadog import *
from litellm.types.services import ServiceLoggerPayload
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from litellm.types.utils import StandardLoggingPayload
from ..additional_logging_utils import AdditionalLoggingUtils
DD_MAX_BATCH_SIZE = 1000 # max number of logs DD API can accept
# max number of logs DD API can accept
DD_MAX_BATCH_SIZE = 1000
# specify which ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with a large number of service types)
DD_LOGGED_SUCCESS_SERVICE_TYPES = [
ServiceTypes.RESET_BUDGET_JOB,
]
class DataDogLogger(
@@ -340,18 +346,16 @@ class DataDogLogger(
- example - Redis is failing / erroring, will be logged on DataDog
"""
try:
import json
_payload_dict = payload.model_dump()
_payload_dict.update(event_metadata or {})
_dd_message_str = json.dumps(_payload_dict, default=str)
_dd_payload = DatadogPayload(
ddsource="litellm",
ddtags="",
hostname="",
ddsource=self._get_datadog_source(),
ddtags=self._get_datadog_tags(),
hostname=self._get_datadog_hostname(),
message=_dd_message_str,
service="litellm-server",
service=self._get_datadog_service(),
status=DataDogStatus.WARN,
)
@@ -377,7 +381,30 @@ class DataDogLogger(
No user has asked for this so far; it might be spammy on Datadog. If the need arises we can implement this
"""
return
try:
# intentionally done. Don't want to log all service types to DD
if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES:
return
_payload_dict = payload.model_dump()
_payload_dict.update(event_metadata or {})
_dd_message_str = json.dumps(_payload_dict, default=str)
_dd_payload = DatadogPayload(
ddsource=self._get_datadog_source(),
ddtags=self._get_datadog_tags(),
hostname=self._get_datadog_hostname(),
message=_dd_message_str,
service=self._get_datadog_service(),
status=DataDogStatus.INFO,
)
self.log_queue.append(_dd_payload)
except Exception as e:
verbose_logger.exception(
f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
)
def _create_v0_logging_payload(
self,
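
The net effect of the change above is that success events are forwarded to Datadog only for an allow-listed set of service types, while failures are always logged. Below is a minimal, self-contained sketch of that filtering pattern; the enum, payload class, and queue here are simplified stand-ins for illustration, not litellm's real types.

import enum
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List


class SketchServiceTypes(str, enum.Enum):
    # simplified stand-in for litellm.types.services.ServiceTypes
    REDIS = "redis"
    RESET_BUDGET_JOB = "reset_budget_job"


# only these services are logged as *success* events, to avoid spamming DD traces
DD_LOGGED_SUCCESS_SERVICE_TYPES = [SketchServiceTypes.RESET_BUDGET_JOB]


@dataclass
class SketchServicePayload:
    service: SketchServiceTypes
    duration: float
    event_metadata: Dict[str, Any] = field(default_factory=dict)


log_queue: List[str] = []


def log_success_event(payload: SketchServicePayload) -> None:
    """Queue a success event for the DD batch sender, but only for allow-listed services."""
    if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES:
        return  # intentionally dropped: service type is not allow-listed
    message = {"service": payload.service.value, "duration": payload.duration}
    message.update(payload.event_metadata)
    log_queue.append(json.dumps(message, default=str))


log_success_event(SketchServicePayload(SketchServiceTypes.REDIS, 0.01))            # dropped
log_success_event(SketchServicePayload(SketchServiceTypes.RESET_BUDGET_JOB, 1.2))  # queued
assert len(log_queue) == 1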


@@ -13,7 +13,7 @@ from typing import Tuple
def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
match = re.match(r"(\d+)(mo|[smhdw]?)", duration)
if not match:
raise ValueError("Invalid duration format")
@@ -42,6 +42,7 @@ def duration_in_seconds(duration: str) -> int:
- "<number>m" - minutes
- "<number>h" - hours
- "<number>d" - days
- "<number>w" - weeks
- "<number>mo" - months
Returns time in seconds till when budget needs to be reset
@@ -56,6 +57,8 @@ def duration_in_seconds(duration: str) -> int:
return value * 3600
elif unit == "d":
return value * 86400
elif unit == "w":
return value * 604800
elif unit == "mo":
now = time.time()
current_time = datetime.fromtimestamp(now)
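
With the regex widened to (mo|[smhdw]?), a "w" suffix is now parsed as weeks. A quick sanity check, assuming the import path shown in the hunk above is available in your installed litellm version:

from litellm.litellm_core_utils.duration_parser import duration_in_seconds

# existing units still behave as before
assert duration_in_seconds(duration="3d") == 3 * 86400
# new: weeks resolve to 7 * 86400 seconds
assert duration_in_seconds(duration="1w") == 604800
assert duration_in_seconds(duration="3w") == 3 * 604800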


@@ -1548,6 +1548,8 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
teams: List[str] = []
sso_user_id: Optional[str] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
@model_validator(mode="before")
@classmethod


@@ -0,0 +1,357 @@
import asyncio
import json
import time
from datetime import datetime, timedelta
from typing import List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_UserTable,
LiteLLM_VerificationToken,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.types.services import ServiceTypes
class ResetBudgetJob:
"""
Resets the budget for all the keys, users, and teams that need it
"""
def __init__(self, proxy_logging_obj: ProxyLogging, prisma_client: PrismaClient):
self.proxy_logging_obj: ProxyLogging = proxy_logging_obj
self.prisma_client: PrismaClient = prisma_client
async def reset_budget(
self,
):
"""
Gets all the non-expired keys for a db, which need spend to be reset
Resets their spend
Updates db
"""
if self.prisma_client is not None:
### RESET KEY BUDGET ###
await self.reset_budget_for_litellm_keys()
### RESET USER BUDGET ###
await self.reset_budget_for_litellm_users()
## Reset Team Budget
await self.reset_budget_for_litellm_teams()
async def reset_budget_for_litellm_keys(self):
"""
Resets the budget for all the litellm keys
Catches Exceptions and logs them
"""
now = datetime.utcnow()
start_time = time.time()
keys_to_reset: Optional[List[LiteLLM_VerificationToken]] = None
try:
keys_to_reset = await self.prisma_client.get_data(
table_name="key", query_type="find_all", expires=now, reset_at=now
)
verbose_proxy_logger.debug(
"Keys to reset %s", json.dumps(keys_to_reset, indent=4, default=str)
)
updated_keys: List[LiteLLM_VerificationToken] = []
failed_keys = []
if keys_to_reset is not None and len(keys_to_reset) > 0:
for key in keys_to_reset:
try:
updated_key = await ResetBudgetJob._reset_budget_for_key(
key=key, current_time=now
)
if updated_key is not None:
updated_keys.append(updated_key)
else:
failed_keys.append(
{"key": key, "error": "Returned None without exception"}
)
except Exception as e:
failed_keys.append({"key": key, "error": str(e)})
verbose_proxy_logger.exception(
"Failed to reset budget for key: %s", key
)
verbose_proxy_logger.debug(
"Updated keys %s", json.dumps(updated_keys, indent=4, default=str)
)
if updated_keys:
await self.prisma_client.update_data(
query_type="update_many",
data_list=updated_keys,
table_name="key",
)
end_time = time.time()
if len(failed_keys) > 0: # If any keys failed to reset
raise Exception(
f"Failed to reset {len(failed_keys)} keys: {json.dumps(failed_keys, default=str)}"
)
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
call_type="reset_budget_keys",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
"keys_found": json.dumps(keys_to_reset, indent=4, default=str),
"num_keys_updated": len(updated_keys),
"keys_updated": json.dumps(updated_keys, indent=4, default=str),
"num_keys_failed": len(failed_keys),
"keys_failed": json.dumps(failed_keys, indent=4, default=str),
},
)
)
except Exception as e:
end_time = time.time()
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
error=e,
call_type="reset_budget_keys",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
"keys_found": json.dumps(keys_to_reset, indent=4, default=str),
},
)
)
verbose_proxy_logger.exception("Failed to reset budget for keys: %s", e)
async def reset_budget_for_litellm_users(self):
"""
Resets the budget for all LiteLLM Internal Users if their budget has expired
"""
now = datetime.utcnow()
start_time = time.time()
users_to_reset: Optional[List[LiteLLM_UserTable]] = None
try:
users_to_reset = await self.prisma_client.get_data(
table_name="user", query_type="find_all", reset_at=now
)
updated_users: List[LiteLLM_UserTable] = []
failed_users = []
if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset:
try:
updated_user = await ResetBudgetJob._reset_budget_for_user(
user=user, current_time=now
)
if updated_user is not None:
updated_users.append(updated_user)
else:
failed_users.append(
{
"user": user,
"error": "Returned None without exception",
}
)
except Exception as e:
failed_users.append({"user": user, "error": str(e)})
verbose_proxy_logger.exception(
"Failed to reset budget for user: %s", user
)
verbose_proxy_logger.debug(
"Updated users %s", json.dumps(updated_users, indent=4, default=str)
)
if updated_users:
await self.prisma_client.update_data(
query_type="update_many",
data_list=updated_users,
table_name="user",
)
end_time = time.time()
if len(failed_users) > 0: # If any users failed to reset
raise Exception(
f"Failed to reset {len(failed_users)} users: {json.dumps(failed_users, default=str)}"
)
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
call_type="reset_budget_users",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_users_found": len(users_to_reset) if users_to_reset else 0,
"users_found": json.dumps(
users_to_reset, indent=4, default=str
),
"num_users_updated": len(updated_users),
"users_updated": json.dumps(
updated_users, indent=4, default=str
),
"num_users_failed": len(failed_users),
"users_failed": json.dumps(failed_users, indent=4, default=str),
},
)
)
except Exception as e:
end_time = time.time()
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
error=e,
call_type="reset_budget_users",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_users_found": len(users_to_reset) if users_to_reset else 0,
"users_found": json.dumps(
users_to_reset, indent=4, default=str
),
},
)
)
verbose_proxy_logger.exception("Failed to reset budget for users: %s", e)
async def reset_budget_for_litellm_teams(self):
"""
Resets the budget for all LiteLLM Internal Teams if their budget has expired
"""
now = datetime.utcnow()
start_time = time.time()
teams_to_reset: Optional[List[LiteLLM_TeamTable]] = None
try:
teams_to_reset = await self.prisma_client.get_data(
table_name="team", query_type="find_all", reset_at=now
)
updated_teams: List[LiteLLM_TeamTable] = []
failed_teams = []
if teams_to_reset is not None and len(teams_to_reset) > 0:
for team in teams_to_reset:
try:
updated_team = await ResetBudgetJob._reset_budget_for_team(
team=team, current_time=now
)
if updated_team is not None:
updated_teams.append(updated_team)
else:
failed_teams.append(
{
"team": team,
"error": "Returned None without exception",
}
)
except Exception as e:
failed_teams.append({"team": team, "error": str(e)})
verbose_proxy_logger.exception(
"Failed to reset budget for team: %s", team
)
verbose_proxy_logger.debug(
"Updated teams %s", json.dumps(updated_teams, indent=4, default=str)
)
if updated_teams:
await self.prisma_client.update_data(
query_type="update_many",
data_list=updated_teams,
table_name="team",
)
end_time = time.time()
if len(failed_teams) > 0: # If any teams failed to reset
raise Exception(
f"Failed to reset {len(failed_teams)} teams: {json.dumps(failed_teams, default=str)}"
)
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
call_type="reset_budget_teams",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
"teams_found": json.dumps(
teams_to_reset, indent=4, default=str
),
"num_teams_updated": len(updated_teams),
"teams_updated": json.dumps(
updated_teams, indent=4, default=str
),
"num_teams_failed": len(failed_teams),
"teams_failed": json.dumps(failed_teams, indent=4, default=str),
},
)
)
except Exception as e:
end_time = time.time()
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
error=e,
call_type="reset_budget_teams",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
"teams_found": json.dumps(
teams_to_reset, indent=4, default=str
),
},
)
)
verbose_proxy_logger.exception("Failed to reset budget for teams: %s", e)
@staticmethod
async def _reset_budget_common(
item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
current_time: datetime,
item_type: str,
) -> Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken]:
"""
Common logic for resetting budget for a team, user, or key
"""
try:
item.spend = 0.0
if hasattr(item, "budget_duration") and item.budget_duration is not None:
duration_s = duration_in_seconds(duration=item.budget_duration)
item.budget_reset_at = current_time + timedelta(seconds=duration_s)
return item
except Exception as e:
verbose_proxy_logger.exception(
"Error resetting budget for %s: %s. Item: %s", item_type, e, item
)
raise e
@staticmethod
async def _reset_budget_for_team(
team: LiteLLM_TeamTable, current_time: datetime
) -> Optional[LiteLLM_TeamTable]:
result = await ResetBudgetJob._reset_budget_common(team, current_time, "team")
return result if isinstance(result, LiteLLM_TeamTable) else None
@staticmethod
async def _reset_budget_for_user(
user: LiteLLM_UserTable, current_time: datetime
) -> Optional[LiteLLM_UserTable]:
result = await ResetBudgetJob._reset_budget_common(user, current_time, "user")
return result if isinstance(result, LiteLLM_UserTable) else None
@staticmethod
async def _reset_budget_for_key(
key: LiteLLM_VerificationToken, current_time: datetime
) -> Optional[LiteLLM_VerificationToken]:
result = await ResetBudgetJob._reset_budget_common(key, current_time, "key")
return result if isinstance(result, LiteLLM_VerificationToken) else None
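
For reference, a hedged sketch of driving the job manually, e.g. from a one-off maintenance script. It assumes prisma_client and proxy_logging_obj are already initialized the way the proxy does at startup; constructing those is not shown here.

import asyncio

from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob


async def run_budget_reset(proxy_logging_obj, prisma_client):
    job = ResetBudgetJob(
        proxy_logging_obj=proxy_logging_obj,
        prisma_client=prisma_client,
    )
    # Resets keys, then users, then teams. Each category emits its own
    # success/failure service event, so one failing category does not
    # prevent the others from being processed.
    await job.reset_budget()


# asyncio.run(run_budget_reset(proxy_logging_obj, prisma_client))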


@@ -22,10 +22,10 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.key_management_endpoints import (
duration_in_seconds,
generate_key_helper_fn,
prepare_metadata_fields,
)


@@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
@@ -37,7 +38,6 @@ from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import (
PrismaClient,
_hash_token_if_needed,
duration_in_seconds,
handle_exception_on_proxy,
)
from litellm.router import Router


@@ -159,6 +159,7 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
)
from litellm.proxy.common_utils.proxy_state import ProxyState
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
@@ -246,7 +247,6 @@ from litellm.proxy.utils import (
get_error_message_str,
get_instance_fn,
hash_token,
reset_budget,
update_spend,
)
from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import (
@@ -3250,8 +3250,14 @@ class ProxyStartupEvent:
### RESET BUDGET ###
if general_settings.get("disable_reset_budget", False) is False:
budget_reset_job = ResetBudgetJob(
proxy_logging_obj=proxy_logging_obj,
prisma_client=prisma_client,
)
scheduler.add_job(
reset_budget, "interval", seconds=interval, args=[prisma_client]
budget_reset_job.reset_budget,
"interval",
seconds=interval,
)
### UPDATE SPEND ###
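
The same wiring in isolation, as a minimal sketch: the proxy's async scheduler (apscheduler in this sketch) now schedules the bound coroutine method directly instead of passing prisma_client through args. The interval value and the client/logging objects are placeholders for values the proxy derives from its config.

from apscheduler.schedulers.asyncio import AsyncIOScheduler

from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob


def schedule_budget_reset(scheduler: AsyncIOScheduler, proxy_logging_obj, prisma_client, interval: int = 600):
    budget_reset_job = ResetBudgetJob(
        proxy_logging_obj=proxy_logging_obj,
        prisma_client=prisma_client,
    )
    # same call shape as the diff above: no args list is needed any more,
    # because the job object already holds its dependencies
    scheduler.add_job(budget_reset_job.reset_budget, "interval", seconds=interval)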


@@ -13,7 +13,6 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
CommonProxyErrors,
@@ -49,7 +48,6 @@ from litellm.proxy._types import (
CallInfo,
LiteLLM_VerificationTokenView,
Member,
ResetTeamBudgetRequest,
UserAPIKeyAuth,
)
from litellm.proxy.db.create_views import (
@@ -2363,73 +2361,6 @@ def _hash_token_if_needed(token: str) -> str:
return token
async def reset_budget(prisma_client: PrismaClient):
"""
Gets all the non-expired keys for a db, which need spend to be reset
Resets their spend
Updates db
"""
if prisma_client is not None:
### RESET KEY BUDGET ###
now = datetime.utcnow()
keys_to_reset = await prisma_client.get_data(
table_name="key", query_type="find_all", expires=now, reset_at=now
)
if keys_to_reset is not None and len(keys_to_reset) > 0:
for key in keys_to_reset:
key.spend = 0.0
duration_s = duration_in_seconds(duration=key.budget_duration)
key.budget_reset_at = now + timedelta(seconds=duration_s)
await prisma_client.update_data(
query_type="update_many", data_list=keys_to_reset, table_name="key"
)
### RESET USER BUDGET ###
now = datetime.utcnow()
users_to_reset = await prisma_client.get_data(
table_name="user", query_type="find_all", reset_at=now
)
if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset:
user.spend = 0.0
duration_s = duration_in_seconds(duration=user.budget_duration)
user.budget_reset_at = now + timedelta(seconds=duration_s)
await prisma_client.update_data(
query_type="update_many", data_list=users_to_reset, table_name="user"
)
## Reset Team Budget
now = datetime.utcnow()
teams_to_reset = await prisma_client.get_data(
table_name="team",
query_type="find_all",
reset_at=now,
)
if teams_to_reset is not None and len(teams_to_reset) > 0:
team_reset_requests = []
for team in teams_to_reset:
duration_s = duration_in_seconds(duration=team.budget_duration)
reset_team_budget_request = ResetTeamBudgetRequest(
team_id=team.team_id,
spend=0.0,
budget_reset_at=now + timedelta(seconds=duration_s),
updated_at=now,
)
team_reset_requests.append(reset_team_budget_request)
await prisma_client.update_data(
query_type="update_many",
data_list=team_reset_requests,
table_name="team",
)
class ProxyUpdateSpend:
@staticmethod
async def update_end_user_spend(


@@ -13,6 +13,7 @@ class ServiceTypes(str, enum.Enum):
REDIS = "redis"
DB = "postgres"
BATCH_WRITE_TO_DB = "batch_write_to_db"
RESET_BUDGET_JOB = "reset_budget_job"
LITELLM = "self"
ROUTER = "router"
AUTH = "auth"
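
Because ServiceTypes is a str-valued enum, the new member compares equal to its raw string, which is convenient when filtering logs or metrics exported to Datadog. A small check, assuming the import path shown elsewhere in this change:

from litellm.types.services import ServiceTypes

assert ServiceTypes.RESET_BUDGET_JOB == "reset_budget_job"
assert ServiceTypes.RESET_BUDGET_JOB.value == "reset_budget_job"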


@@ -0,0 +1,687 @@
import os
import sys
import time
import traceback
import uuid
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dotenv import load_dotenv
import json
import asyncio
load_dotenv()
import os
import tempfile
from uuid import uuid4
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy._types import (
LiteLLM_VerificationToken,
LiteLLM_UserTable,
LiteLLM_TeamTable,
)
from litellm.types.services import ServiceTypes
# Note: In our "fake" items we use dicts with fields that our fake reset functions modify.
# In a real-world scenario, these would be instances of LiteLLM_VerificationToken, LiteLLM_UserTable, etc.
@pytest.mark.asyncio
async def test_reset_budget_keys_partial_failure():
"""
Test that if one key fails to reset, the failure for that key does not block processing of the other keys.
We simulate six keys where the first fails and the remaining five succeed.
"""
# Arrange
key1 = {
"id": "key1",
"spend": 10.0,
"budget_duration": 60,
} # Will trigger simulated failure
key2 = {"id": "key2", "spend": 15.0, "budget_duration": 60} # Should be updated
key3 = {"id": "key3", "spend": 20.0, "budget_duration": 60} # Should be updated
key4 = {"id": "key4", "spend": 25.0, "budget_duration": 60} # Should be updated
key5 = {"id": "key5", "spend": 30.0, "budget_duration": 60} # Should be updated
key6 = {"id": "key6", "spend": 35.0, "budget_duration": 60} # Should be updated
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(
return_value=[key1, key2, key3, key4, key5, key6]
)
prisma_client.update_data = AsyncMock()
# Using a dummy logging object with async hooks mocked out.
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
now = datetime.utcnow()
async def fake_reset_key(key, current_time):
if key["id"] == "key1":
# Simulate a failure on key1 (for example, this might be due to an invariant check)
raise Exception("Simulated failure for key1")
else:
# Simulate successful reset modification
key["spend"] = 0.0
# Compute a new reset time based on the budget duration
key["budget_reset_at"] = (
current_time + timedelta(seconds=key["budget_duration"])
).isoformat()
return key
with patch.object(
ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key
) as mock_reset_key:
# Call the method; even though one key fails, the loop should process both
await job.reset_budget_for_litellm_keys()
# Allow any created tasks (logging hooks) to schedule
await asyncio.sleep(0.1)
# Assert that the helper was called for 6 keys
assert mock_reset_key.call_count == 6
# Assert that update_data was called once with a list containing all 6 keys
prisma_client.update_data.assert_awaited_once()
update_call = prisma_client.update_data.call_args
assert update_call.kwargs.get("table_name") == "key"
updated_keys = update_call.kwargs.get("data_list", [])
assert len(updated_keys) == 5
assert updated_keys[0]["id"] == "key2"
assert updated_keys[1]["id"] == "key3"
assert updated_keys[2]["id"] == "key4"
assert updated_keys[3]["id"] == "key5"
assert updated_keys[4]["id"] == "key6"
# Verify that the failure logging hook was scheduled (due to the failure for key1)
failure_hook_calls = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list
)
# There should be one failure hook call for keys (with call_type "reset_budget_keys")
assert any(
call.kwargs.get("call_type") == "reset_budget_keys"
for call in failure_hook_calls
)
@pytest.mark.asyncio
async def test_reset_budget_users_partial_failure():
"""
Test that if one user fails to reset, the reset loop still processes the other users.
We simulate six users where the first fails and the remaining five are updated.
"""
user1 = {
"id": "user1",
"spend": 20.0,
"budget_duration": 120,
} # Will trigger simulated failure
user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Should be updated
user3 = {"id": "user3", "spend": 30.0, "budget_duration": 120} # Should be updated
user4 = {"id": "user4", "spend": 35.0, "budget_duration": 120} # Should be updated
user5 = {"id": "user5", "spend": 40.0, "budget_duration": 120} # Should be updated
user6 = {"id": "user6", "spend": 45.0, "budget_duration": 120} # Should be updated
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(
return_value=[user1, user2, user3, user4, user5, user6]
)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_user(user, current_time):
if user["id"] == "user1":
raise Exception("Simulated failure for user1")
else:
user["spend"] = 0.0
user["budget_reset_at"] = (
current_time + timedelta(seconds=user["budget_duration"])
).isoformat()
return user
with patch.object(
ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user
) as mock_reset_user:
await job.reset_budget_for_litellm_users()
await asyncio.sleep(0.1)
assert mock_reset_user.call_count == 6
prisma_client.update_data.assert_awaited_once()
update_call = prisma_client.update_data.call_args
assert update_call.kwargs.get("table_name") == "user"
updated_users = update_call.kwargs.get("data_list", [])
assert len(updated_users) == 5
assert updated_users[0]["id"] == "user2"
assert updated_users[1]["id"] == "user3"
assert updated_users[2]["id"] == "user4"
assert updated_users[3]["id"] == "user5"
assert updated_users[4]["id"] == "user6"
failure_hook_calls = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list
)
assert any(
call.kwargs.get("call_type") == "reset_budget_users"
for call in failure_hook_calls
)
@pytest.mark.asyncio
async def test_reset_budget_teams_partial_failure():
"""
Test that if one team fails to reset, the loop processes both teams and only updates the ones that succeeded.
We simulate two teams where the first fails and the second is updated.
"""
team1 = {
"id": "team1",
"spend": 30.0,
"budget_duration": 180,
} # Will trigger simulated failure
team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180} # Should be updated
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(return_value=[team1, team2])
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_team(team, current_time):
if team["id"] == "team1":
raise Exception("Simulated failure for team1")
else:
team["spend"] = 0.0
team["budget_reset_at"] = (
current_time + timedelta(seconds=team["budget_duration"])
).isoformat()
return team
with patch.object(
ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team
) as mock_reset_team:
await job.reset_budget_for_litellm_teams()
await asyncio.sleep(0.1)
assert mock_reset_team.call_count == 2
prisma_client.update_data.assert_awaited_once()
update_call = prisma_client.update_data.call_args
assert update_call.kwargs.get("table_name") == "team"
updated_teams = update_call.kwargs.get("data_list", [])
assert len(updated_teams) == 1
assert updated_teams[0]["id"] == "team2"
failure_hook_calls = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list
)
assert any(
call.kwargs.get("call_type") == "reset_budget_teams"
for call in failure_hook_calls
)
@pytest.mark.asyncio
async def test_reset_budget_continues_other_categories_on_failure():
"""
Test that executing the overall reset_budget() method continues to process keys, users, and teams,
even if one of the sub-categories (here, users) experiences a partial failure.
In this simulation:
- All keys are processed successfully.
- One of the two users fails.
- All teams are processed successfully.
We then assert that:
- update_data is called for each category with the correctly updated items.
- Each get_data call is made (indicating that one failing category did not abort the others).
"""
# Arrange dummy items for each table
key1 = {"id": "key1", "spend": 10.0, "budget_duration": 60}
key2 = {"id": "key2", "spend": 15.0, "budget_duration": 60}
user1 = {
"id": "user1",
"spend": 20.0,
"budget_duration": 120,
} # Will fail in user reset
user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Succeeds
team1 = {"id": "team1", "spend": 30.0, "budget_duration": 180}
team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180}
prisma_client = MagicMock()
async def fake_get_data(*, table_name, query_type, **kwargs):
if table_name == "key":
return [key1, key2]
elif table_name == "user":
return [user1, user2]
elif table_name == "team":
return [team1, team2]
return []
prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_key(key, current_time):
key["spend"] = 0.0
key["budget_reset_at"] = (
current_time + timedelta(seconds=key["budget_duration"])
).isoformat()
return key
async def fake_reset_user(user, current_time):
if user["id"] == "user1":
raise Exception("Simulated failure for user1")
user["spend"] = 0.0
user["budget_reset_at"] = (
current_time + timedelta(seconds=user["budget_duration"])
).isoformat()
return user
async def fake_reset_team(team, current_time):
team["spend"] = 0.0
team["budget_reset_at"] = (
current_time + timedelta(seconds=team["budget_duration"])
).isoformat()
return team
with patch.object(
ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key
) as mock_reset_key, patch.object(
ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user
) as mock_reset_user, patch.object(
ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team
) as mock_reset_team:
# Call the overall reset_budget method.
await job.reset_budget()
await asyncio.sleep(0.1)
# Verify that get_data was called for each table. We can check the table names across calls.
called_tables = {
call.kwargs.get("table_name") for call in prisma_client.get_data.await_args_list
}
assert called_tables == {"key", "user", "team"}
# Verify that update_data was called three times (one per category)
assert prisma_client.update_data.await_count == 3
calls = prisma_client.update_data.await_args_list
# Check keys update: both keys succeed.
keys_call = calls[0]
assert keys_call.kwargs.get("table_name") == "key"
assert len(keys_call.kwargs.get("data_list", [])) == 2
# Check users update: only user2 succeeded.
users_call = calls[1]
assert users_call.kwargs.get("table_name") == "user"
users_updated = users_call.kwargs.get("data_list", [])
assert len(users_updated) == 1
assert users_updated[0]["id"] == "user2"
# Check teams update: both teams succeed.
teams_call = calls[2]
assert teams_call.kwargs.get("table_name") == "team"
assert len(teams_call.kwargs.get("data_list", [])) == 2
# ---------------------------------------------------------------------------
# Additional tests for service logger behavior (keys, users, teams)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_logger_keys_success():
"""
Test that when resetting keys succeeds (all keys are updated) the service
logger success hook is called with the correct event metadata and no exception is logged.
"""
keys = [
{"id": "key1", "spend": 10.0, "budget_duration": 60},
{"id": "key2", "spend": 15.0, "budget_duration": 60},
]
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(return_value=keys)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_key(key, current_time):
key["spend"] = 0.0
key["budget_reset_at"] = (
current_time + timedelta(seconds=key["budget_duration"])
).isoformat()
return key
with patch.object(
ResetBudgetJob,
"_reset_budget_for_key",
side_effect=fake_reset_key,
):
with patch(
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
) as mock_verbose_exc:
await job.reset_budget_for_litellm_keys()
# Allow async logging task to complete
await asyncio.sleep(0.1)
mock_verbose_exc.assert_not_called()
# Verify success hook call
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
)
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_keys_found") == len(keys)
assert event_metadata.get("num_keys_updated") == len(keys)
assert event_metadata.get("num_keys_failed") == 0
# Failure hook should not be executed.
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called()
@pytest.mark.asyncio
async def test_service_logger_keys_failure():
"""
Test that when a key reset fails the service logger failure hook is called,
the event metadata reflects the number of keys processed, and that the verbose
logger exception is called.
"""
keys = [
{"id": "key1", "spend": 10.0, "budget_duration": 60},
{"id": "key2", "spend": 15.0, "budget_duration": 60},
]
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(return_value=keys)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_key(key, current_time):
if key["id"] == "key1":
raise Exception("Simulated failure for key1")
key["spend"] = 0.0
key["budget_reset_at"] = (
current_time + timedelta(seconds=key["budget_duration"])
).isoformat()
return key
with patch.object(
ResetBudgetJob,
"_reset_budget_for_key",
side_effect=fake_reset_key,
):
with patch(
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
) as mock_verbose_exc:
await job.reset_budget_for_litellm_keys()
await asyncio.sleep(0.1)
# Expect at least one exception logged (the inner error and the outer catch)
assert mock_verbose_exc.call_count >= 1
# Verify exception was logged with correct message
assert any(
"Failed to reset budget for key" in str(call.args)
for call in mock_verbose_exc.call_args_list
)
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
)
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_keys_found") == len(keys)
keys_found_str = event_metadata.get("keys_found", "")
assert "key1" in keys_found_str
# Success hook should not be called.
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
@pytest.mark.asyncio
async def test_service_logger_users_success():
"""
Test that when resetting users succeeds the service logger success hook is called with
the correct metadata and no exception is logged.
"""
users = [
{"id": "user1", "spend": 20.0, "budget_duration": 120},
{"id": "user2", "spend": 25.0, "budget_duration": 120},
]
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(return_value=users)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_user(user, current_time):
user["spend"] = 0.0
user["budget_reset_at"] = (
current_time + timedelta(seconds=user["budget_duration"])
).isoformat()
return user
with patch.object(
ResetBudgetJob,
"_reset_budget_for_user",
side_effect=fake_reset_user,
):
with patch(
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
) as mock_verbose_exc:
await job.reset_budget_for_litellm_users()
await asyncio.sleep(0.1)
mock_verbose_exc.assert_not_called()
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
)
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_users_found") == len(users)
assert event_metadata.get("num_users_updated") == len(users)
assert event_metadata.get("num_users_failed") == 0
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called()
@pytest.mark.asyncio
async def test_service_logger_users_failure():
"""
Test that a failure during user reset calls the failure hook with appropriate metadata,
logs the exception, and does not call the success hook.
"""
users = [
{"id": "user1", "spend": 20.0, "budget_duration": 120},
{"id": "user2", "spend": 25.0, "budget_duration": 120},
]
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(return_value=users)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_user(user, current_time):
if user["id"] == "user1":
raise Exception("Simulated failure for user1")
user["spend"] = 0.0
user["budget_reset_at"] = (
current_time + timedelta(seconds=user["budget_duration"])
).isoformat()
return user
with patch.object(
ResetBudgetJob,
"_reset_budget_for_user",
side_effect=fake_reset_user,
):
with patch(
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
) as mock_verbose_exc:
await job.reset_budget_for_litellm_users()
await asyncio.sleep(0.1)
# Verify exception logging
assert mock_verbose_exc.call_count >= 1
# Verify exception was logged with correct message
assert any(
"Failed to reset budget for user" in str(call.args)
for call in mock_verbose_exc.call_args_list
)
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
)
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_users_found") == len(users)
users_found_str = event_metadata.get("users_found", "")
assert "user1" in users_found_str
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
@pytest.mark.asyncio
async def test_service_logger_teams_success():
"""
Test that when resetting teams is successful the service logger success hook is called with
the proper metadata and nothing is logged as an exception.
"""
teams = [
{"id": "team1", "spend": 30.0, "budget_duration": 180},
{"id": "team2", "spend": 35.0, "budget_duration": 180},
]
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(return_value=teams)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_team(team, current_time):
team["spend"] = 0.0
team["budget_reset_at"] = (
current_time + timedelta(seconds=team["budget_duration"])
).isoformat()
return team
with patch.object(
ResetBudgetJob,
"_reset_budget_for_team",
side_effect=fake_reset_team,
):
with patch(
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
) as mock_verbose_exc:
await job.reset_budget_for_litellm_teams()
await asyncio.sleep(0.1)
mock_verbose_exc.assert_not_called()
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
)
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_teams_found") == len(teams)
assert event_metadata.get("num_teams_updated") == len(teams)
assert event_metadata.get("num_teams_failed") == 0
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called()
@pytest.mark.asyncio
async def test_service_logger_teams_failure():
"""
Test that a failure during team reset triggers the failure hook with proper metadata,
results in an exception log and no success hook call.
"""
teams = [
{"id": "team1", "spend": 30.0, "budget_duration": 180},
{"id": "team2", "spend": 35.0, "budget_duration": 180},
]
prisma_client = MagicMock()
prisma_client.get_data = AsyncMock(return_value=teams)
prisma_client.update_data = AsyncMock()
proxy_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj = MagicMock()
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
async def fake_reset_team(team, current_time):
if team["id"] == "team1":
raise Exception("Simulated failure for team1")
team["spend"] = 0.0
team["budget_reset_at"] = (
current_time + timedelta(seconds=team["budget_duration"])
).isoformat()
return team
with patch.object(
ResetBudgetJob,
"_reset_budget_for_team",
side_effect=fake_reset_team,
):
with patch(
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
) as mock_verbose_exc:
await job.reset_budget_for_litellm_teams()
await asyncio.sleep(0.1)
# Verify exception logging
assert mock_verbose_exc.call_count >= 1
# Verify exception was logged with correct message
assert any(
"Failed to reset budget for team" in str(call.args)
for call in mock_verbose_exc.call_args_list
)
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
)
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_teams_found") == len(teams)
teams_found_str = event_metadata.get("teams_found", "")
assert "team1" in teams_found_str
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()


@@ -18,9 +18,7 @@ import pytest
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, headers
from litellm.proxy.utils import (
duration_in_seconds,
)
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.litellm_core_utils.duration_parser import (
get_last_day_of_month,
_extract_from_regex,
@@ -721,6 +719,14 @@ def test_duration_in_seconds():
assert value - expected_duration < 2
def test_duration_in_seconds_basic():
assert duration_in_seconds(duration="3s") == 3
assert duration_in_seconds(duration="3m") == 180
assert duration_in_seconds(duration="3h") == 10800
assert duration_in_seconds(duration="3d") == 259200
assert duration_in_seconds(duration="3w") == 1814400
def test_get_llm_provider_ft_models():
"""
All ft prefixed models should map to OpenAI