mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
(Bug Fix + Better Observability) - BudgetResetJob: (#8562)
* use class ResetBudgetJob * refactor reset budget job * update reset_budget job * refactor reset budget job * fix LiteLLM_UserTable * refactor reset budget job * add telemetry for reset budget job * dd - log service success/failure on DD * add detailed reset budget reset info on DD * initialize_scheduled_background_jobs * refactor reset budget job * trigger service failure hook when fails to reset a budget for team, key, user * fix resetBudgetJob * unit testing for ResetBudgetJob * test_duration_in_seconds_basic * testing for triggering service logging * fix logs on test teams fail * remove unused imports * fix import duration in s * duration_in_seconds
This commit is contained in:
parent
a8717ea124
commit
c8d31a209b
11 changed files with 1107 additions and 87 deletions
|
@ -35,12 +35,18 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.datadog import *
|
||||
from litellm.types.services import ServiceLoggerPayload
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
from ..additional_logging_utils import AdditionalLoggingUtils
|
||||
|
||||
DD_MAX_BATCH_SIZE = 1000 # max number of logs DD API can accept
|
||||
# max number of logs DD API can accept
|
||||
DD_MAX_BATCH_SIZE = 1000
|
||||
|
||||
# specify what ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with large number of service types)
|
||||
DD_LOGGED_SUCCESS_SERVICE_TYPES = [
|
||||
ServiceTypes.RESET_BUDGET_JOB,
|
||||
]
|
||||
|
||||
|
||||
class DataDogLogger(
|
||||
|
@ -340,18 +346,16 @@ class DataDogLogger(
|
|||
|
||||
- example - Redis is failing / erroring, will be logged on DataDog
|
||||
"""
|
||||
|
||||
try:
|
||||
import json
|
||||
|
||||
_payload_dict = payload.model_dump()
|
||||
_payload_dict.update(event_metadata or {})
|
||||
_dd_message_str = json.dumps(_payload_dict, default=str)
|
||||
_dd_payload = DatadogPayload(
|
||||
ddsource="litellm",
|
||||
ddtags="",
|
||||
hostname="",
|
||||
ddsource=self._get_datadog_source(),
|
||||
ddtags=self._get_datadog_tags(),
|
||||
hostname=self._get_datadog_hostname(),
|
||||
message=_dd_message_str,
|
||||
service="litellm-server",
|
||||
service=self._get_datadog_service(),
|
||||
status=DataDogStatus.WARN,
|
||||
)
|
||||
|
||||
|
@ -377,7 +381,30 @@ class DataDogLogger(
|
|||
|
||||
No user has asked for this so far, this might be spammy on datatdog. If need arises we can implement this
|
||||
"""
|
||||
return
|
||||
try:
|
||||
# intentionally done. Don't want to log all service types to DD
|
||||
if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES:
|
||||
return
|
||||
|
||||
_payload_dict = payload.model_dump()
|
||||
_payload_dict.update(event_metadata or {})
|
||||
|
||||
_dd_message_str = json.dumps(_payload_dict, default=str)
|
||||
_dd_payload = DatadogPayload(
|
||||
ddsource=self._get_datadog_source(),
|
||||
ddtags=self._get_datadog_tags(),
|
||||
hostname=self._get_datadog_hostname(),
|
||||
message=_dd_message_str,
|
||||
service=self._get_datadog_service(),
|
||||
status=DataDogStatus.INFO,
|
||||
)
|
||||
|
||||
self.log_queue.append(_dd_payload)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
|
||||
)
|
||||
|
||||
def _create_v0_logging_payload(
|
||||
self,
|
||||
|
|
|
@ -13,7 +13,7 @@ from typing import Tuple
|
|||
|
||||
|
||||
def _extract_from_regex(duration: str) -> Tuple[int, str]:
|
||||
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
|
||||
match = re.match(r"(\d+)(mo|[smhdw]?)", duration)
|
||||
|
||||
if not match:
|
||||
raise ValueError("Invalid duration format")
|
||||
|
@ -42,6 +42,7 @@ def duration_in_seconds(duration: str) -> int:
|
|||
- "<number>m" - minutes
|
||||
- "<number>h" - hours
|
||||
- "<number>d" - days
|
||||
- "<number>w" - weeks
|
||||
- "<number>mo" - months
|
||||
|
||||
Returns time in seconds till when budget needs to be reset
|
||||
|
@ -56,6 +57,8 @@ def duration_in_seconds(duration: str) -> int:
|
|||
return value * 3600
|
||||
elif unit == "d":
|
||||
return value * 86400
|
||||
elif unit == "w":
|
||||
return value * 604800
|
||||
elif unit == "mo":
|
||||
now = time.time()
|
||||
current_time = datetime.fromtimestamp(now)
|
||||
|
|
|
@ -1548,6 +1548,8 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
|
|||
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
|
||||
teams: List[str] = []
|
||||
sso_user_id: Optional[str] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
357
litellm/proxy/common_utils/reset_budget_job.py
Normal file
357
litellm/proxy/common_utils/reset_budget_job.py
Normal file
|
@ -0,0 +1,357 @@
|
|||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_VerificationToken,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
class ResetBudgetJob:
|
||||
"""
|
||||
Resets the budget for all the keys, users, and teams that need it
|
||||
"""
|
||||
|
||||
def __init__(self, proxy_logging_obj: ProxyLogging, prisma_client: PrismaClient):
|
||||
self.proxy_logging_obj: ProxyLogging = proxy_logging_obj
|
||||
self.prisma_client: PrismaClient = prisma_client
|
||||
|
||||
async def reset_budget(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
Gets all the non-expired keys for a db, which need spend to be reset
|
||||
|
||||
Resets their spend
|
||||
|
||||
Updates db
|
||||
"""
|
||||
if self.prisma_client is not None:
|
||||
### RESET KEY BUDGET ###
|
||||
await self.reset_budget_for_litellm_keys()
|
||||
|
||||
### RESET USER BUDGET ###
|
||||
await self.reset_budget_for_litellm_users()
|
||||
|
||||
## Reset Team Budget
|
||||
await self.reset_budget_for_litellm_teams()
|
||||
|
||||
async def reset_budget_for_litellm_keys(self):
|
||||
"""
|
||||
Resets the budget for all the litellm keys
|
||||
|
||||
Catches Exceptions and logs them
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
start_time = time.time()
|
||||
keys_to_reset: Optional[List[LiteLLM_VerificationToken]] = None
|
||||
try:
|
||||
keys_to_reset = await self.prisma_client.get_data(
|
||||
table_name="key", query_type="find_all", expires=now, reset_at=now
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Keys to reset %s", json.dumps(keys_to_reset, indent=4, default=str)
|
||||
)
|
||||
updated_keys: List[LiteLLM_VerificationToken] = []
|
||||
failed_keys = []
|
||||
if keys_to_reset is not None and len(keys_to_reset) > 0:
|
||||
for key in keys_to_reset:
|
||||
try:
|
||||
updated_key = await ResetBudgetJob._reset_budget_for_key(
|
||||
key=key, current_time=now
|
||||
)
|
||||
if updated_key is not None:
|
||||
updated_keys.append(updated_key)
|
||||
else:
|
||||
failed_keys.append(
|
||||
{"key": key, "error": "Returned None without exception"}
|
||||
)
|
||||
except Exception as e:
|
||||
failed_keys.append({"key": key, "error": str(e)})
|
||||
verbose_proxy_logger.exception(
|
||||
"Failed to reset budget for key: %s", key
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Updated keys %s", json.dumps(updated_keys, indent=4, default=str)
|
||||
)
|
||||
|
||||
if updated_keys:
|
||||
await self.prisma_client.update_data(
|
||||
query_type="update_many",
|
||||
data_list=updated_keys,
|
||||
table_name="key",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
if len(failed_keys) > 0: # If any keys failed to reset
|
||||
raise Exception(
|
||||
f"Failed to reset {len(failed_keys)} keys: {json.dumps(failed_keys, default=str)}"
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
call_type="reset_budget_keys",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
|
||||
"keys_found": json.dumps(keys_to_reset, indent=4, default=str),
|
||||
"num_keys_updated": len(updated_keys),
|
||||
"keys_updated": json.dumps(updated_keys, indent=4, default=str),
|
||||
"num_keys_failed": len(failed_keys),
|
||||
"keys_failed": json.dumps(failed_keys, indent=4, default=str),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
error=e,
|
||||
call_type="reset_budget_keys",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
|
||||
"keys_found": json.dumps(keys_to_reset, indent=4, default=str),
|
||||
},
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.exception("Failed to reset budget for keys: %s", e)
|
||||
|
||||
async def reset_budget_for_litellm_users(self):
|
||||
"""
|
||||
Resets the budget for all LiteLLM Internal Users if their budget has expired
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
start_time = time.time()
|
||||
users_to_reset: Optional[List[LiteLLM_UserTable]] = None
|
||||
try:
|
||||
users_to_reset = await self.prisma_client.get_data(
|
||||
table_name="user", query_type="find_all", reset_at=now
|
||||
)
|
||||
updated_users: List[LiteLLM_UserTable] = []
|
||||
failed_users = []
|
||||
if users_to_reset is not None and len(users_to_reset) > 0:
|
||||
for user in users_to_reset:
|
||||
try:
|
||||
updated_user = await ResetBudgetJob._reset_budget_for_user(
|
||||
user=user, current_time=now
|
||||
)
|
||||
if updated_user is not None:
|
||||
updated_users.append(updated_user)
|
||||
else:
|
||||
failed_users.append(
|
||||
{
|
||||
"user": user,
|
||||
"error": "Returned None without exception",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
failed_users.append({"user": user, "error": str(e)})
|
||||
verbose_proxy_logger.exception(
|
||||
"Failed to reset budget for user: %s", user
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Updated users %s", json.dumps(updated_users, indent=4, default=str)
|
||||
)
|
||||
if updated_users:
|
||||
await self.prisma_client.update_data(
|
||||
query_type="update_many",
|
||||
data_list=updated_users,
|
||||
table_name="user",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
if len(failed_users) > 0: # If any users failed to reset
|
||||
raise Exception(
|
||||
f"Failed to reset {len(failed_users)} users: {json.dumps(failed_users, default=str)}"
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
call_type="reset_budget_users",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_users_found": len(users_to_reset) if users_to_reset else 0,
|
||||
"users_found": json.dumps(
|
||||
users_to_reset, indent=4, default=str
|
||||
),
|
||||
"num_users_updated": len(updated_users),
|
||||
"users_updated": json.dumps(
|
||||
updated_users, indent=4, default=str
|
||||
),
|
||||
"num_users_failed": len(failed_users),
|
||||
"users_failed": json.dumps(failed_users, indent=4, default=str),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
error=e,
|
||||
call_type="reset_budget_users",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_users_found": len(users_to_reset) if users_to_reset else 0,
|
||||
"users_found": json.dumps(
|
||||
users_to_reset, indent=4, default=str
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.exception("Failed to reset budget for users: %s", e)
|
||||
|
||||
async def reset_budget_for_litellm_teams(self):
|
||||
"""
|
||||
Resets the budget for all LiteLLM Internal Teams if their budget has expired
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
start_time = time.time()
|
||||
teams_to_reset: Optional[List[LiteLLM_TeamTable]] = None
|
||||
try:
|
||||
teams_to_reset = await self.prisma_client.get_data(
|
||||
table_name="team", query_type="find_all", reset_at=now
|
||||
)
|
||||
updated_teams: List[LiteLLM_TeamTable] = []
|
||||
failed_teams = []
|
||||
if teams_to_reset is not None and len(teams_to_reset) > 0:
|
||||
for team in teams_to_reset:
|
||||
try:
|
||||
updated_team = await ResetBudgetJob._reset_budget_for_team(
|
||||
team=team, current_time=now
|
||||
)
|
||||
if updated_team is not None:
|
||||
updated_teams.append(updated_team)
|
||||
else:
|
||||
failed_teams.append(
|
||||
{
|
||||
"team": team,
|
||||
"error": "Returned None without exception",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
failed_teams.append({"team": team, "error": str(e)})
|
||||
verbose_proxy_logger.exception(
|
||||
"Failed to reset budget for team: %s", team
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Updated teams %s", json.dumps(updated_teams, indent=4, default=str)
|
||||
)
|
||||
if updated_teams:
|
||||
await self.prisma_client.update_data(
|
||||
query_type="update_many",
|
||||
data_list=updated_teams,
|
||||
table_name="team",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
if len(failed_teams) > 0: # If any teams failed to reset
|
||||
raise Exception(
|
||||
f"Failed to reset {len(failed_teams)} teams: {json.dumps(failed_teams, default=str)}"
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
call_type="reset_budget_teams",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
|
||||
"teams_found": json.dumps(
|
||||
teams_to_reset, indent=4, default=str
|
||||
),
|
||||
"num_teams_updated": len(updated_teams),
|
||||
"teams_updated": json.dumps(
|
||||
updated_teams, indent=4, default=str
|
||||
),
|
||||
"num_teams_failed": len(failed_teams),
|
||||
"teams_failed": json.dumps(failed_teams, indent=4, default=str),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
error=e,
|
||||
call_type="reset_budget_teams",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
|
||||
"teams_found": json.dumps(
|
||||
teams_to_reset, indent=4, default=str
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.exception("Failed to reset budget for teams: %s", e)
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_common(
|
||||
item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
|
||||
current_time: datetime,
|
||||
item_type: str,
|
||||
) -> Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken]:
|
||||
"""
|
||||
Common logic for resetting budget for a team, user, or key
|
||||
"""
|
||||
try:
|
||||
item.spend = 0.0
|
||||
if hasattr(item, "budget_duration") and item.budget_duration is not None:
|
||||
duration_s = duration_in_seconds(duration=item.budget_duration)
|
||||
item.budget_reset_at = current_time + timedelta(seconds=duration_s)
|
||||
return item
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"Error resetting budget for %s: %s. Item: %s", item_type, e, item
|
||||
)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_for_team(
|
||||
team: LiteLLM_TeamTable, current_time: datetime
|
||||
) -> Optional[LiteLLM_TeamTable]:
|
||||
result = await ResetBudgetJob._reset_budget_common(team, current_time, "team")
|
||||
return result if isinstance(result, LiteLLM_TeamTable) else None
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_for_user(
|
||||
user: LiteLLM_UserTable, current_time: datetime
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
result = await ResetBudgetJob._reset_budget_common(user, current_time, "user")
|
||||
return result if isinstance(result, LiteLLM_UserTable) else None
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_for_key(
|
||||
key: LiteLLM_VerificationToken, current_time: datetime
|
||||
) -> Optional[LiteLLM_VerificationToken]:
|
||||
result = await ResetBudgetJob._reset_budget_common(key, current_time, "key")
|
||||
return result if isinstance(result, LiteLLM_VerificationToken) else None
|
|
@ -22,10 +22,10 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
|||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
duration_in_seconds,
|
||||
generate_key_helper_fn,
|
||||
prepare_metadata_fields,
|
||||
)
|
||||
|
|
|
@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s
|
|||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_cache_key_object,
|
||||
|
@ -37,7 +38,6 @@ from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
|||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
_hash_token_if_needed,
|
||||
duration_in_seconds,
|
||||
handle_exception_on_proxy,
|
||||
)
|
||||
from litellm.router import Router
|
||||
|
|
|
@ -159,6 +159,7 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
|
|||
remove_sensitive_info_from_deployment,
|
||||
)
|
||||
from litellm.proxy.common_utils.proxy_state import ProxyState
|
||||
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
|
||||
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
||||
|
@ -246,7 +247,6 @@ from litellm.proxy.utils import (
|
|||
get_error_message_str,
|
||||
get_instance_fn,
|
||||
hash_token,
|
||||
reset_budget,
|
||||
update_spend,
|
||||
)
|
||||
from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import (
|
||||
|
@ -3250,8 +3250,14 @@ class ProxyStartupEvent:
|
|||
|
||||
### RESET BUDGET ###
|
||||
if general_settings.get("disable_reset_budget", False) is False:
|
||||
budget_reset_job = ResetBudgetJob(
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
scheduler.add_job(
|
||||
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
||||
budget_reset_job.reset_budget,
|
||||
"interval",
|
||||
seconds=interval,
|
||||
)
|
||||
|
||||
### UPDATE SPEND ###
|
||||
|
|
|
@ -13,7 +13,6 @@ from email.mime.multipart import MIMEMultipart
|
|||
from email.mime.text import MIMEText
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload
|
||||
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
CommonProxyErrors,
|
||||
|
@ -49,7 +48,6 @@ from litellm.proxy._types import (
|
|||
CallInfo,
|
||||
LiteLLM_VerificationTokenView,
|
||||
Member,
|
||||
ResetTeamBudgetRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.db.create_views import (
|
||||
|
@ -2363,73 +2361,6 @@ def _hash_token_if_needed(token: str) -> str:
|
|||
return token
|
||||
|
||||
|
||||
async def reset_budget(prisma_client: PrismaClient):
|
||||
"""
|
||||
Gets all the non-expired keys for a db, which need spend to be reset
|
||||
|
||||
Resets their spend
|
||||
|
||||
Updates db
|
||||
"""
|
||||
if prisma_client is not None:
|
||||
### RESET KEY BUDGET ###
|
||||
now = datetime.utcnow()
|
||||
keys_to_reset = await prisma_client.get_data(
|
||||
table_name="key", query_type="find_all", expires=now, reset_at=now
|
||||
)
|
||||
|
||||
if keys_to_reset is not None and len(keys_to_reset) > 0:
|
||||
for key in keys_to_reset:
|
||||
key.spend = 0.0
|
||||
duration_s = duration_in_seconds(duration=key.budget_duration)
|
||||
key.budget_reset_at = now + timedelta(seconds=duration_s)
|
||||
|
||||
await prisma_client.update_data(
|
||||
query_type="update_many", data_list=keys_to_reset, table_name="key"
|
||||
)
|
||||
|
||||
### RESET USER BUDGET ###
|
||||
now = datetime.utcnow()
|
||||
users_to_reset = await prisma_client.get_data(
|
||||
table_name="user", query_type="find_all", reset_at=now
|
||||
)
|
||||
|
||||
if users_to_reset is not None and len(users_to_reset) > 0:
|
||||
for user in users_to_reset:
|
||||
user.spend = 0.0
|
||||
duration_s = duration_in_seconds(duration=user.budget_duration)
|
||||
user.budget_reset_at = now + timedelta(seconds=duration_s)
|
||||
|
||||
await prisma_client.update_data(
|
||||
query_type="update_many", data_list=users_to_reset, table_name="user"
|
||||
)
|
||||
|
||||
## Reset Team Budget
|
||||
now = datetime.utcnow()
|
||||
teams_to_reset = await prisma_client.get_data(
|
||||
table_name="team",
|
||||
query_type="find_all",
|
||||
reset_at=now,
|
||||
)
|
||||
|
||||
if teams_to_reset is not None and len(teams_to_reset) > 0:
|
||||
team_reset_requests = []
|
||||
for team in teams_to_reset:
|
||||
duration_s = duration_in_seconds(duration=team.budget_duration)
|
||||
reset_team_budget_request = ResetTeamBudgetRequest(
|
||||
team_id=team.team_id,
|
||||
spend=0.0,
|
||||
budget_reset_at=now + timedelta(seconds=duration_s),
|
||||
updated_at=now,
|
||||
)
|
||||
team_reset_requests.append(reset_team_budget_request)
|
||||
await prisma_client.update_data(
|
||||
query_type="update_many",
|
||||
data_list=team_reset_requests,
|
||||
table_name="team",
|
||||
)
|
||||
|
||||
|
||||
class ProxyUpdateSpend:
|
||||
@staticmethod
|
||||
async def update_end_user_spend(
|
||||
|
|
|
@ -13,6 +13,7 @@ class ServiceTypes(str, enum.Enum):
|
|||
REDIS = "redis"
|
||||
DB = "postgres"
|
||||
BATCH_WRITE_TO_DB = "batch_write_to_db"
|
||||
RESET_BUDGET_JOB = "reset_budget_job"
|
||||
LITELLM = "self"
|
||||
ROUTER = "router"
|
||||
AUTH = "auth"
|
||||
|
|
687
tests/litellm_utils_tests/test_proxy_budget_reset.py
Normal file
687
tests/litellm_utils_tests/test_proxy_budget_reset.py
Normal file
|
@ -0,0 +1,687 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
import tempfile
|
||||
from uuid import uuid4
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_VerificationToken,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_TeamTable,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
# Note: In our "fake" items we use dicts with fields that our fake reset functions modify.
|
||||
# In a real-world scenario, these would be instances of LiteLLM_VerificationToken, LiteLLM_UserTable, etc.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_budget_keys_partial_failure():
|
||||
"""
|
||||
Test that if one key fails to reset, the failure for that key does not block processing of the other keys.
|
||||
We simulate two keys where the first fails and the second succeeds.
|
||||
"""
|
||||
# Arrange
|
||||
key1 = {
|
||||
"id": "key1",
|
||||
"spend": 10.0,
|
||||
"budget_duration": 60,
|
||||
} # Will trigger simulated failure
|
||||
key2 = {"id": "key2", "spend": 15.0, "budget_duration": 60} # Should be updated
|
||||
key3 = {"id": "key3", "spend": 20.0, "budget_duration": 60} # Should be updated
|
||||
key4 = {"id": "key4", "spend": 25.0, "budget_duration": 60} # Should be updated
|
||||
key5 = {"id": "key5", "spend": 30.0, "budget_duration": 60} # Should be updated
|
||||
key6 = {"id": "key6", "spend": 35.0, "budget_duration": 60} # Should be updated
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.get_data = AsyncMock(
|
||||
return_value=[key1, key2, key3, key4, key5, key6]
|
||||
)
|
||||
prisma_client.update_data = AsyncMock()
|
||||
|
||||
# Using a dummy logging object with async hooks mocked out.
|
||||
proxy_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
|
||||
|
||||
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
async def fake_reset_key(key, current_time):
|
||||
if key["id"] == "key1":
|
||||
# Simulate a failure on key1 (for example, this might be due to an invariant check)
|
||||
raise Exception("Simulated failure for key1")
|
||||
else:
|
||||
# Simulate successful reset modification
|
||||
key["spend"] = 0.0
|
||||
# Compute a new reset time based on the budget duration
|
||||
key["budget_reset_at"] = (
|
||||
current_time + timedelta(seconds=key["budget_duration"])
|
||||
).isoformat()
|
||||
return key
|
||||
|
||||
with patch.object(
|
||||
ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key
|
||||
) as mock_reset_key:
|
||||
# Call the method; even though one key fails, the loop should process both
|
||||
await job.reset_budget_for_litellm_keys()
|
||||
# Allow any created tasks (logging hooks) to schedule
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Assert that the helper was called for 6 keys
|
||||
assert mock_reset_key.call_count == 6
|
||||
|
||||
# Assert that update_data was called once with a list containing all 6 keys
|
||||
prisma_client.update_data.assert_awaited_once()
|
||||
update_call = prisma_client.update_data.call_args
|
||||
assert update_call.kwargs.get("table_name") == "key"
|
||||
updated_keys = update_call.kwargs.get("data_list", [])
|
||||
assert len(updated_keys) == 5
|
||||
assert updated_keys[0]["id"] == "key2"
|
||||
assert updated_keys[1]["id"] == "key3"
|
||||
assert updated_keys[2]["id"] == "key4"
|
||||
assert updated_keys[3]["id"] == "key5"
|
||||
assert updated_keys[4]["id"] == "key6"
|
||||
|
||||
# Verify that the failure logging hook was scheduled (due to the failure for key1)
|
||||
failure_hook_calls = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list
|
||||
)
|
||||
# There should be one failure hook call for keys (with call_type "reset_budget_keys")
|
||||
assert any(
|
||||
call.kwargs.get("call_type") == "reset_budget_keys"
|
||||
for call in failure_hook_calls
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_budget_users_partial_failure():
|
||||
"""
|
||||
Test that if one user fails to reset, the reset loop still processes the other users.
|
||||
We simulate two users where the first fails and the second is updated.
|
||||
"""
|
||||
user1 = {
|
||||
"id": "user1",
|
||||
"spend": 20.0,
|
||||
"budget_duration": 120,
|
||||
} # Will trigger simulated failure
|
||||
user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Should be updated
|
||||
user3 = {"id": "user3", "spend": 30.0, "budget_duration": 120} # Should be updated
|
||||
user4 = {"id": "user4", "spend": 35.0, "budget_duration": 120} # Should be updated
|
||||
user5 = {"id": "user5", "spend": 40.0, "budget_duration": 120} # Should be updated
|
||||
user6 = {"id": "user6", "spend": 45.0, "budget_duration": 120} # Should be updated
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.get_data = AsyncMock(
|
||||
return_value=[user1, user2, user3, user4, user5, user6]
|
||||
)
|
||||
prisma_client.update_data = AsyncMock()
|
||||
|
||||
proxy_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
|
||||
|
||||
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
|
||||
|
||||
async def fake_reset_user(user, current_time):
|
||||
if user["id"] == "user1":
|
||||
raise Exception("Simulated failure for user1")
|
||||
else:
|
||||
user["spend"] = 0.0
|
||||
user["budget_reset_at"] = (
|
||||
current_time + timedelta(seconds=user["budget_duration"])
|
||||
).isoformat()
|
||||
return user
|
||||
|
||||
with patch.object(
|
||||
ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user
|
||||
) as mock_reset_user:
|
||||
await job.reset_budget_for_litellm_users()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert mock_reset_user.call_count == 6
|
||||
prisma_client.update_data.assert_awaited_once()
|
||||
update_call = prisma_client.update_data.call_args
|
||||
assert update_call.kwargs.get("table_name") == "user"
|
||||
updated_users = update_call.kwargs.get("data_list", [])
|
||||
assert len(updated_users) == 5
|
||||
assert updated_users[0]["id"] == "user2"
|
||||
assert updated_users[1]["id"] == "user3"
|
||||
assert updated_users[2]["id"] == "user4"
|
||||
assert updated_users[3]["id"] == "user5"
|
||||
assert updated_users[4]["id"] == "user6"
|
||||
|
||||
failure_hook_calls = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list
|
||||
)
|
||||
assert any(
|
||||
call.kwargs.get("call_type") == "reset_budget_users"
|
||||
for call in failure_hook_calls
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_budget_teams_partial_failure():
    """
    Verify partial-failure handling when resetting team budgets.

    Two teams are fetched; the reset helper raises for the first and succeeds
    for the second. The job must still attempt a reset for both teams, persist
    only the team that reset successfully, and report the failure through the
    service logger failure hook.
    """
    # team1's reset raises; team2's reset succeeds.
    failing_team = {"id": "team1", "spend": 30.0, "budget_duration": 180}
    passing_team = {"id": "team2", "spend": 35.0, "budget_duration": 180}

    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(return_value=[failing_team, passing_team])
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def fake_reset_team(team, current_time):
        # Simulate a hard failure for team1 only; team2 follows the happy path.
        if team["id"] == "team1":
            raise Exception("Simulated failure for team1")
        team["spend"] = 0.0
        team["budget_reset_at"] = (
            current_time + timedelta(seconds=team["budget_duration"])
        ).isoformat()
        return team

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team
    ) as mock_reset_team:
        await job.reset_budget_for_litellm_teams()
        await asyncio.sleep(0.1)  # let the fire-and-forget logging task finish

        # Both teams were attempted, but only team2 was written back.
        assert mock_reset_team.call_count == 2
        prisma_client.update_data.assert_awaited_once()
        update_call = prisma_client.update_data.call_args
        assert update_call.kwargs.get("table_name") == "team"
        updated_teams = update_call.kwargs.get("data_list", [])
        assert len(updated_teams) == 1
        assert updated_teams[0]["id"] == "team2"

        # The failure must be surfaced via the service logger failure hook.
        failure_hook_calls = (
            proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list
        )
        assert any(
            call.kwargs.get("call_type") == "reset_budget_teams"
            for call in failure_hook_calls
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_budget_continues_other_categories_on_failure():
    """
    Run the top-level reset_budget() and confirm that one failing category
    does not abort the others.

    Simulation:
      - both keys reset successfully,
      - user1 fails while user2 succeeds,
      - both teams reset successfully.

    Expectations:
      - get_data is queried for all three tables (the user failure does not
        short-circuit keys or teams), and
      - update_data is invoked once per category, carrying only the items
        that actually reset.
    """
    key1 = {"id": "key1", "spend": 10.0, "budget_duration": 60}
    key2 = {"id": "key2", "spend": 15.0, "budget_duration": 60}
    user1 = {"id": "user1", "spend": 20.0, "budget_duration": 120}  # will fail
    user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120}  # succeeds
    team1 = {"id": "team1", "spend": 30.0, "budget_duration": 180}
    team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180}

    # Dispatch table for the fake prisma get_data.
    table_rows = {
        "key": [key1, key2],
        "user": [user1, user2],
        "team": [team1, team2],
    }

    prisma_client = MagicMock()

    async def fake_get_data(*, table_name, query_type, **kwargs):
        return table_rows.get(table_name, [])

    prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    def _apply_reset(item, current_time):
        # Shared success path: zero the spend, roll the reset timestamp forward.
        item["spend"] = 0.0
        item["budget_reset_at"] = (
            current_time + timedelta(seconds=item["budget_duration"])
        ).isoformat()
        return item

    async def fake_reset_key(key, current_time):
        return _apply_reset(key, current_time)

    async def fake_reset_user(user, current_time):
        if user["id"] == "user1":
            raise Exception("Simulated failure for user1")
        return _apply_reset(user, current_time)

    async def fake_reset_team(team, current_time):
        return _apply_reset(team, current_time)

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key
    ), patch.object(
        ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user
    ), patch.object(
        ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team
    ):
        # Exercise the overall entry point, then let async logging settle.
        await job.reset_budget()
        await asyncio.sleep(0.1)

    # Every table should have been queried despite the user failure.
    called_tables = {
        call.kwargs.get("table_name") for call in prisma_client.get_data.await_args_list
    }
    assert called_tables == {"key", "user", "team"}

    # One update per category, in key -> user -> team order.
    assert prisma_client.update_data.await_count == 3
    keys_call, users_call, teams_call = prisma_client.update_data.await_args_list

    # Keys: both succeed.
    assert keys_call.kwargs.get("table_name") == "key"
    assert len(keys_call.kwargs.get("data_list", [])) == 2

    # Users: only user2 survives.
    assert users_call.kwargs.get("table_name") == "user"
    users_updated = users_call.kwargs.get("data_list", [])
    assert len(users_updated) == 1
    assert users_updated[0]["id"] == "user2"

    # Teams: both succeed.
    assert teams_call.kwargs.get("table_name") == "team"
    assert len(teams_call.kwargs.get("data_list", [])) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional tests for service logger behavior (keys, users, teams)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_logger_keys_success():
    """
    When every key resets cleanly, the service logger success hook must be
    called exactly once with accurate counts, the failure hook must never
    fire, and nothing may be logged via verbose_proxy_logger.exception.
    """
    keys = [
        {"id": "key1", "spend": 10.0, "budget_duration": 60},
        {"id": "key2", "spend": 15.0, "budget_duration": 60},
    ]
    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(return_value=keys)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def _reset_key_ok(key, current_time):
        # Happy path: zero the spend and push the reset timestamp forward.
        key["spend"] = 0.0
        key["budget_reset_at"] = (
            current_time + timedelta(seconds=key["budget_duration"])
        ).isoformat()
        return key

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_key", side_effect=_reset_key_ok
    ), patch(
        "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
    ) as mock_verbose_exc:
        await job.reset_budget_for_litellm_keys()
        await asyncio.sleep(0.1)  # allow the async logging task to complete
        mock_verbose_exc.assert_not_called()

    success_hook = proxy_logging_obj.service_logging_obj.async_service_success_hook
    success_hook.assert_called_once()
    _, kwargs = success_hook.call_args
    event_metadata = kwargs.get("event_metadata", {})
    assert event_metadata.get("num_keys_found") == len(keys)
    assert event_metadata.get("num_keys_updated") == len(keys)
    assert event_metadata.get("num_keys_failed") == 0
    # No failures were simulated, so the failure hook must stay silent.
    proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_logger_keys_failure():
    """
    A key-reset failure must be logged through verbose_proxy_logger.exception
    with a "Failed to reset budget for key" message, must trigger the service
    logger failure hook with metadata covering the keys that were found, and
    must never trigger the success hook.
    """
    keys = [
        {"id": "key1", "spend": 10.0, "budget_duration": 60},
        {"id": "key2", "spend": 15.0, "budget_duration": 60},
    ]
    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(return_value=keys)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def _reset_key_flaky(key, current_time):
        # key1 blows up; key2 follows the normal reset path.
        if key["id"] == "key1":
            raise Exception("Simulated failure for key1")
        key["spend"] = 0.0
        key["budget_reset_at"] = (
            current_time + timedelta(seconds=key["budget_duration"])
        ).isoformat()
        return key

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_key", side_effect=_reset_key_flaky
    ), patch(
        "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
    ) as mock_verbose_exc:
        await job.reset_budget_for_litellm_keys()
        await asyncio.sleep(0.1)

        # At least one exception log (the inner error and/or the outer catch).
        assert mock_verbose_exc.call_count >= 1
        assert any(
            "Failed to reset budget for key" in str(call.args)
            for call in mock_verbose_exc.call_args_list
        )

    failure_hook = proxy_logging_obj.service_logging_obj.async_service_failure_hook
    failure_hook.assert_called_once()
    _, kwargs = failure_hook.call_args
    event_metadata = kwargs.get("event_metadata", {})
    assert event_metadata.get("num_keys_found") == len(keys)
    assert "key1" in event_metadata.get("keys_found", "")
    # A failed run must never be reported as a success.
    proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_logger_users_success():
    """
    When every user resets cleanly, the service logger success hook must fire
    once with accurate counts, the failure hook must stay silent, and no
    exception may be logged.
    """
    users = [
        {"id": "user1", "spend": 20.0, "budget_duration": 120},
        {"id": "user2", "spend": 25.0, "budget_duration": 120},
    ]
    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(return_value=users)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def _reset_user_ok(user, current_time):
        # Happy path: zero the spend and push the reset timestamp forward.
        user["spend"] = 0.0
        user["budget_reset_at"] = (
            current_time + timedelta(seconds=user["budget_duration"])
        ).isoformat()
        return user

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_user", side_effect=_reset_user_ok
    ), patch(
        "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
    ) as mock_verbose_exc:
        await job.reset_budget_for_litellm_users()
        await asyncio.sleep(0.1)  # allow the async logging task to complete
        mock_verbose_exc.assert_not_called()

    success_hook = proxy_logging_obj.service_logging_obj.async_service_success_hook
    success_hook.assert_called_once()
    _, kwargs = success_hook.call_args
    event_metadata = kwargs.get("event_metadata", {})
    assert event_metadata.get("num_users_found") == len(users)
    assert event_metadata.get("num_users_updated") == len(users)
    assert event_metadata.get("num_users_failed") == 0
    proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_logger_users_failure():
    """
    A user-reset failure must be logged through verbose_proxy_logger.exception
    with a "Failed to reset budget for user" message, must invoke the failure
    hook with metadata covering the users that were found, and must never
    invoke the success hook.
    """
    users = [
        {"id": "user1", "spend": 20.0, "budget_duration": 120},
        {"id": "user2", "spend": 25.0, "budget_duration": 120},
    ]
    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(return_value=users)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def _reset_user_flaky(user, current_time):
        # user1 blows up; user2 follows the normal reset path.
        if user["id"] == "user1":
            raise Exception("Simulated failure for user1")
        user["spend"] = 0.0
        user["budget_reset_at"] = (
            current_time + timedelta(seconds=user["budget_duration"])
        ).isoformat()
        return user

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_user", side_effect=_reset_user_flaky
    ), patch(
        "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
    ) as mock_verbose_exc:
        await job.reset_budget_for_litellm_users()
        await asyncio.sleep(0.1)

        # Verify the exception was logged with the expected message.
        assert mock_verbose_exc.call_count >= 1
        assert any(
            "Failed to reset budget for user" in str(call.args)
            for call in mock_verbose_exc.call_args_list
        )

    failure_hook = proxy_logging_obj.service_logging_obj.async_service_failure_hook
    failure_hook.assert_called_once()
    _, kwargs = failure_hook.call_args
    event_metadata = kwargs.get("event_metadata", {})
    assert event_metadata.get("num_users_found") == len(users)
    assert "user1" in event_metadata.get("users_found", "")
    proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_logger_teams_success():
    """
    When every team resets cleanly, the service logger success hook must fire
    once with accurate counts, the failure hook must stay silent, and no
    exception may be logged.
    """
    teams = [
        {"id": "team1", "spend": 30.0, "budget_duration": 180},
        {"id": "team2", "spend": 35.0, "budget_duration": 180},
    ]
    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(return_value=teams)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def _reset_team_ok(team, current_time):
        # Happy path: zero the spend and push the reset timestamp forward.
        team["spend"] = 0.0
        team["budget_reset_at"] = (
            current_time + timedelta(seconds=team["budget_duration"])
        ).isoformat()
        return team

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_team", side_effect=_reset_team_ok
    ), patch(
        "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
    ) as mock_verbose_exc:
        await job.reset_budget_for_litellm_teams()
        await asyncio.sleep(0.1)  # allow the async logging task to complete
        mock_verbose_exc.assert_not_called()

    success_hook = proxy_logging_obj.service_logging_obj.async_service_success_hook
    success_hook.assert_called_once()
    _, kwargs = success_hook.call_args
    event_metadata = kwargs.get("event_metadata", {})
    assert event_metadata.get("num_teams_found") == len(teams)
    assert event_metadata.get("num_teams_updated") == len(teams)
    assert event_metadata.get("num_teams_failed") == 0
    proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_logger_teams_failure():
    """
    A team-reset failure must be logged through verbose_proxy_logger.exception
    with a "Failed to reset budget for team" message, must invoke the failure
    hook with metadata covering the teams that were found, and must never
    invoke the success hook.
    """
    teams = [
        {"id": "team1", "spend": 30.0, "budget_duration": 180},
        {"id": "team2", "spend": 35.0, "budget_duration": 180},
    ]
    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(return_value=teams)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def _reset_team_flaky(team, current_time):
        # team1 blows up; team2 follows the normal reset path.
        if team["id"] == "team1":
            raise Exception("Simulated failure for team1")
        team["spend"] = 0.0
        team["budget_reset_at"] = (
            current_time + timedelta(seconds=team["budget_duration"])
        ).isoformat()
        return team

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_team", side_effect=_reset_team_flaky
    ), patch(
        "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
    ) as mock_verbose_exc:
        await job.reset_budget_for_litellm_teams()
        await asyncio.sleep(0.1)

        # Verify the exception was logged with the expected message.
        assert mock_verbose_exc.call_count >= 1
        assert any(
            "Failed to reset budget for team" in str(call.args)
            for call in mock_verbose_exc.call_args_list
        )

    failure_hook = proxy_logging_obj.service_logging_obj.async_service_failure_hook
    failure_hook.assert_called_once()
    _, kwargs = failure_hook.call_args
    event_metadata = kwargs.get("event_metadata", {})
    assert event_metadata.get("num_teams_found") == len(teams)
    assert "team1" in event_metadata.get("teams_found", "")
    proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
|
|
@ -18,9 +18,7 @@ import pytest
|
|||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, headers
|
||||
from litellm.proxy.utils import (
|
||||
duration_in_seconds,
|
||||
)
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.litellm_core_utils.duration_parser import (
|
||||
get_last_day_of_month,
|
||||
_extract_from_regex,
|
||||
|
@ -721,6 +719,14 @@ def test_duration_in_seconds():
|
|||
assert value - expected_duration < 2
|
||||
|
||||
|
||||
def test_duration_in_seconds_basic():
    """Sanity-check duration_in_seconds() for each supported unit suffix."""
    # suffix -> expected seconds: s(econds), m(inutes), h(ours), d(ays), w(eeks)
    expectations = {
        "3s": 3,
        "3m": 180,
        "3h": 10800,
        "3d": 259200,
        "3w": 1814400,
    }
    for duration_str, expected_seconds in expectations.items():
        assert duration_in_seconds(duration=duration_str) == expected_seconds
|
||||
|
||||
|
||||
def test_get_llm_provider_ft_models():
|
||||
"""
|
||||
All ft prefixed models should map to OpenAI
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue