This commit is contained in:
Laurien 2025-04-24 01:00:22 -07:00 committed by GitHub
commit e6593570a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 1022 additions and 47 deletions

View file

@ -885,6 +885,10 @@ class BudgetNewRequest(LiteLLMPydanticObjectBase):
default=None,
description="Max budget for each model (e.g. {'gpt-4o': {'max_budget': '0.0000001', 'budget_duration': '1d', 'tpm_limit': 1000, 'rpm_limit': 1000}})",
)
budget_reset_at: Optional[datetime] = Field(
default=None,
description="Datetime when the budget is reset",
)
class BudgetRequest(LiteLLMPydanticObjectBase):
@ -1206,10 +1210,18 @@ class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase):
rpm_limit: Optional[int] = None
model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
model_config = ConfigDict(protected_namespaces=())
# Extends LiteLLM_BudgetTable with the two required fields declared below.
class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable):
"""Represents all params for a LiteLLM_BudgetTable record"""
# Identifier of the budget row; the reset job uses it as the upsert key and
# to look up the end users attached to the budget.
budget_id: str
# Creation timestamp; the reset job falls back to it when a budget has no
# budget_reset_at yet.
created_at: datetime
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
"""
Used to track spend of a user_id within a team_id

View file

@ -1,12 +1,14 @@
import asyncio
import json
import time
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List, Literal, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import (
LiteLLM_BudgetTableFull,
LiteLLM_EndUserTable,
LiteLLM_TeamTable,
LiteLLM_UserTable,
LiteLLM_VerificationToken,
@ -44,6 +46,141 @@ class ResetBudgetJob:
## Reset Team Budget
await self.reset_budget_for_litellm_teams()
### RESET ENDUSER (Customer) BUDGET and corresponding Budget duration ###
await self.reset_budget_for_litellm_endusers()
async def reset_budget_for_litellm_endusers(self):
"""
Resets the spend of all LiteLLM End-Users (Customers) whose budget is due for a reset.

The `budget_reset_at` date of every budget returned by the lookup is also advanced
and written back. Errors (including partial failures) are reported through the
service logger's failure hook and logged; they are not re-raised to the caller.
"""
# One timezone-aware "now" is used both for the budget lookup and for
# computing the new reset dates.
now = datetime.now(timezone.utc)
# Wall-clock start, reported to the service logger together with the duration.
start_time = time.time()
endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None
budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None
updated_endusers: List[LiteLLM_EndUserTable] = []
# Entries are {"enduser": <item>, "error": <message>} dicts.
failed_endusers = []
try:
# Look up the budgets to process, passing `now` as the reset cut-off.
budgets_to_reset = await self.prisma_client.get_data(
table_name="budget", query_type="find_all", reset_at=now
)
if budgets_to_reset is not None and len(budgets_to_reset) > 0:
# The helper mutates each budget in place and returns the same object,
# so the rebinding of `budget` has no further effect.
for budget in budgets_to_reset:
budget = await ResetBudgetJob._reset_budget_reset_at_date(
budget, now
)
# Persist the new budget_reset_at values.
await self.prisma_client.update_data(
query_type="update_many",
data_list=budgets_to_reset,
table_name="budget",
)
# Load every end user that references one of those budgets.
endusers_to_reset = await self.prisma_client.get_data(
table_name="enduser",
query_type="find_all",
budget_id_list=[budget.budget_id for budget in budgets_to_reset],
)
if endusers_to_reset is not None and len(endusers_to_reset) > 0:
# Each end user is reset inside its own try/except so that one failure
# does not stop the remaining ones; failures are collected for reporting.
for enduser in endusers_to_reset:
try:
updated_enduser = (
await ResetBudgetJob._reset_budget_for_enduser(
enduser=enduser
)
)
if updated_enduser is not None:
updated_endusers.append(updated_enduser)
else:
failed_endusers.append(
{
"enduser": enduser,
"error": "Returned None without exception",
}
)
except Exception as e:
failed_endusers.append({"enduser": enduser, "error": str(e)})
verbose_proxy_logger.exception(
"Failed to reset budget for enduser: %s", enduser
)
verbose_proxy_logger.debug(
"Updated users %s",
json.dumps(updated_endusers, indent=4, default=str),
)
# Only the end users that were reset successfully are written back.
await self.prisma_client.update_data(
query_type="update_many",
data_list=updated_endusers,
table_name="enduser",
)
end_time = time.time()
# Raising here routes a partial failure to the failure hook in the except
# block below; the writes above have already been issued at this point.
if len(failed_endusers) > 0: # If any endusers failed to reset
raise Exception(
f"Failed to reset {len(failed_endusers)} endusers: {json.dumps(failed_endusers, default=str)}"
)
# The success hook is scheduled as a task and is not awaited here.
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
call_type="reset_budget_endusers",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_budgets_found": len(budgets_to_reset)
if budgets_to_reset
else 0,
"budgets_found": json.dumps(
budgets_to_reset, indent=4, default=str
),
"num_endusers_found": len(endusers_to_reset)
if endusers_to_reset
else 0,
"endusers_found": json.dumps(
endusers_to_reset, indent=4, default=str
),
"num_endusers_updated": len(updated_endusers),
"endusers_updated": json.dumps(
updated_endusers, indent=4, default=str
),
"num_endusers_failed": len(failed_endusers),
"endusers_failed": json.dumps(
failed_endusers, indent=4, default=str
),
},
)
)
except Exception as e:
# Any error above ends up here: the failure hook is scheduled with the
# budgets/end users found so far, the exception is logged, and nothing is
# re-raised.
end_time = time.time()
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
error=e,
call_type="reset_budget_endusers",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_budgets_found": len(budgets_to_reset)
if budgets_to_reset
else 0,
"budgets_found": json.dumps(
budgets_to_reset, indent=4, default=str
),
"num_endusers_found": len(endusers_to_reset)
if endusers_to_reset
else 0,
"endusers_found": json.dumps(
endusers_to_reset, indent=4, default=str
),
},
)
)
verbose_proxy_logger.exception("Failed to reset budget for endusers: %s", e)
async def reset_budget_for_litellm_keys(self):
"""
Resets the budget for all the litellm keys
@ -355,6 +492,46 @@ class ResetBudgetJob:
)
return user
@staticmethod
async def _reset_budget_for_enduser(
    enduser: LiteLLM_EndUserTable,
) -> Optional[LiteLLM_EndUserTable]:
    """
    Zero the accumulated spend of a single end user.

    The object is modified in place and the same instance is returned.
    If the assignment fails, the error is logged together with the
    offending item and then re-raised to the caller.
    """
    try:
        setattr(enduser, "spend", 0.0)
    except Exception as exc:
        verbose_proxy_logger.exception(
            "Error resetting budget for enduser: %s. Item: %s", exc, enduser
        )
        raise exc
    else:
        return enduser
@staticmethod
async def _reset_budget_reset_at_date(
    budget: LiteLLM_BudgetTableFull, current_time: datetime
) -> Optional[LiteLLM_BudgetTableFull]:
    """
    Advance `budget.budget_reset_at` to the next reset date.

    Budgets without a `budget_duration` are returned untouched. Otherwise:

    - if the budget has no reset date yet and its first period
      (`created_at + duration`) still ends after `current_time`, that
      period end becomes the reset date;
    - in every other case the reset date is `current_time + duration`.

    The budget is modified in place and the same instance is returned.
    Any error is logged together with the budget and re-raised.
    """
    try:
        if budget.budget_duration is None:
            return budget

        period = timedelta(
            seconds=duration_in_seconds(duration=budget.budget_duration)
        )

        # Fallback for pre-existing budgets that never had a reset date:
        # honour the period that started when the budget was created.
        first_period_end = None
        if budget.budget_reset_at is None:
            first_period_end = budget.created_at + period

        if first_period_end is not None and first_period_end > current_time:
            budget.budget_reset_at = first_period_end
        else:
            budget.budget_reset_at = current_time + period
    except Exception as exc:
        verbose_proxy_logger.exception(
            "Error resetting budget_reset_at for budget: %s. Item: %s", exc, budget
        )
        raise exc
    return budget
@staticmethod
async def _reset_budget_for_key(
key: LiteLLM_VerificationToken, current_time: datetime

View file

@ -12,8 +12,10 @@ All /budget management endpoints
"""
#### BUDGET TABLE MANAGEMENT ####
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.utils import jsonify_object
@ -51,6 +53,12 @@ async def new_budget(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
# If no budget_reset_at is supplied but a budget_duration is, initialise budget_reset_at to the end of the first duration interval (now + budget_duration).
if budget_obj.budget_reset_at is None and budget_obj.budget_duration is not None:
budget_obj.budget_reset_at = datetime.utcnow() + timedelta(
seconds=duration_in_seconds(duration=budget_obj.budget_duration)
)
budget_obj_json = budget_obj.model_dump(exclude_none=True)
budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries
response = await prisma_client.db.litellm_budgettable.create(

View file

@ -1410,6 +1410,8 @@ class PrismaClient:
"key",
"config",
"spend",
"enduser",
"budget",
"team",
"user_notification",
"combined_view",
@ -1424,6 +1426,7 @@ class PrismaClient:
] = None, # pagination, number of rows to getch when find_all==True
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
budget_id_list: Optional[List[str]] = None,
):
args_passed_in = locals()
start_time = time.time()
@ -1603,6 +1606,29 @@ class PrismaClient:
order={"startTime": "desc"},
)
return response
elif table_name == "budget" and reset_at is not None:
if query_type == "find_all":
response = await self.db.litellm_budgettable.find_many(
where={ # type:ignore
"OR": [
{
"AND": [
{"budget_reset_at": None},
{"NOT": {"budget_duration": None}},
]
},
{"budget_reset_at": {"lt": reset_at}},
]
}
)
return response
elif table_name == "enduser" and budget_id_list is not None:
if query_type == "find_all":
response = await self.db.litellm_endusertable.find_many(
where={"budget_id": {"in": budget_id_list}}
)
return response
elif table_name == "team":
if query_type == "find_unique":
response = await self.db.litellm_teamtable.find_unique(
@ -1916,7 +1942,9 @@ class PrismaClient:
user_id: Optional[str] = None,
team_id: Optional[str] = None,
query_type: Literal["update", "update_many"] = "update",
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
table_name: Optional[
Literal["user", "key", "config", "spend", "team", "enduser", "budget"]
] = None,
update_key_values: Optional[dict] = None,
update_key_values_custom_query: Optional[dict] = None,
):
@ -2083,6 +2111,68 @@ class PrismaClient:
verbose_proxy_logger.info(
"\033[91m" + "DB User Table Batch update succeeded" + "\033[0m"
)
elif (
table_name is not None
and table_name == "enduser"
and query_type == "update_many"
and data_list is not None
and isinstance(data_list, list)
):
"""
Batch write update queries
"""
batcher = self.db.batch_()
for enduser in data_list:
try:
data_json = self.jsonify_object(
data=enduser.model_dump(exclude_none=True)
)
except Exception:
data_json = self.jsonify_object(data=enduser.dict())
batcher.litellm_endusertable.upsert(
where={"user_id": enduser.user_id}, # type: ignore
data={
"create": {**data_json}, # type: ignore
"update": {
**data_json # type: ignore
}, # just update end-user-specified values, if it already exists
},
)
await batcher.commit()
verbose_proxy_logger.info(
"\033[91m" + "DB End User Table Batch update succeeded" + "\033[0m"
)
elif (
table_name is not None
and table_name == "budget"
and query_type == "update_many"
and data_list is not None
and isinstance(data_list, list)
):
"""
Batch write update queries
"""
batcher = self.db.batch_()
for budget in data_list:
try:
data_json = self.jsonify_object(
data=budget.model_dump(exclude_none=True)
)
except Exception:
data_json = self.jsonify_object(data=budget.dict())
batcher.litellm_budgettable.upsert(
where={"budget_id": budget.budget_id}, # type: ignore
data={
"create": {**data_json}, # type: ignore
"update": {
**data_json # type: ignore
}, # just update end-user-specified values, if it already exists
},
)
await batcher.commit()
verbose_proxy_logger.info(
"\033[91m" + "DB Budget Table Batch update succeeded" + "\033[0m"
)
elif (
table_name is not None
and table_name == "team"

View file

@ -1,12 +1,9 @@
import asyncio
import json
import os
import sys
import time
from datetime import datetime, timedelta
from datetime import datetime, timezone
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
@ -20,8 +17,14 @@ from litellm.proxy.utils import ProxyLogging
# Mock classes for testing
class MockPrismaClient:
def __init__(self):
self.data = {"key": [], "user": [], "team": []}
self.updated_data = {"key": [], "user": [], "team": []}
self.data = {"key": [], "user": [], "team": [], "budget": [], "enduser": []}
self.updated_data = {
"key": [],
"user": [],
"team": [],
"budget": [],
"enduser": [],
}
async def get_data(self, table_name, query_type, **kwargs):
return self.data.get(table_name, [])
@ -145,9 +148,48 @@ def test_reset_budget_for_team(reset_budget_job, mock_prisma_client):
assert updated_team.budget_reset_at > now
def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client):
    """An end user's spend is zeroed and the budget's reset date moves forward."""
    # Arrange: one budget that is due right now, with a single end user on it.
    started_at = datetime.now(timezone.utc)

    budget_attrs = {
        "max_budget": 500.0,
        "budget_duration": "1d",
        "budget_reset_at": started_at,
        "budget_id": "test-budget-1",
    }
    budget_stub = type("LiteLLM_BudgetTable", (), budget_attrs)

    enduser_attrs = {
        "spend": 20.0,
        "litellm_budget_table": budget_stub,
        "user_id": "test-enduser-1",
    }
    enduser_stub = type("LiteLLM_EndUserTable", (), enduser_attrs)

    mock_prisma_client.data["budget"] = [budget_stub]
    mock_prisma_client.data["enduser"] = [enduser_stub]

    # Act
    asyncio.run(reset_budget_job.reset_budget_for_litellm_endusers())

    # Assert: exactly one row was written per table ...
    written = mock_prisma_client.updated_data
    assert len(written["enduser"]) == 1
    assert len(written["budget"]) == 1

    # ... the spend is back to zero and the reset date lies in the future.
    saved_enduser = written["enduser"][0]
    saved_budget = written["budget"][0]
    assert saved_enduser.spend == 0.0
    assert saved_budget.budget_reset_at > started_at
def test_reset_budget_all(reset_budget_job, mock_prisma_client):
# Setup test data
now = datetime.utcnow()
now = datetime.now(timezone.utc)
# Create test objects for all three types
test_key = type(
@ -183,9 +225,32 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client):
},
)
test_budget = type(
"LiteLLM_BudgetTable",
(),
{
"max_budget": 500.0,
"budget_duration": "1d",
"budget_reset_at": now,
"budget_id": "test-budget-1",
},
)
test_enduser = type(
"LiteLLM_EndUserTable",
(),
{
"spend": 20.0,
"litellm_budget_table": test_budget,
"user_id": "test-enduser-1",
},
)
mock_prisma_client.data["key"] = [test_key]
mock_prisma_client.data["user"] = [test_user]
mock_prisma_client.data["team"] = [test_team]
mock_prisma_client.data["budget"] = [test_budget]
mock_prisma_client.data["enduser"] = [test_enduser]
# Run the test
asyncio.run(reset_budget_job.reset_budget())
@ -194,8 +259,11 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client):
assert len(mock_prisma_client.updated_data["key"]) == 1
assert len(mock_prisma_client.updated_data["user"]) == 1
assert len(mock_prisma_client.updated_data["team"]) == 1
assert len(mock_prisma_client.updated_data["enduser"]) == 1
assert len(mock_prisma_client.updated_data["budget"]) == 1
# Check that all spends were reset to 0
assert mock_prisma_client.updated_data["key"][0].spend == 0.0
assert mock_prisma_client.updated_data["user"][0].spend == 0.0
assert mock_prisma_client.updated_data["team"][0].spend == 0.0
assert mock_prisma_client.updated_data["enduser"][0].spend == 0.0

View file

@ -1,32 +1,22 @@
import asyncio
import os
import sys
import time
import traceback
import uuid
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dotenv import load_dotenv
import json
import asyncio
load_dotenv()
import os
import tempfile
from uuid import uuid4
from litellm.proxy._types import LiteLLM_BudgetTableFull
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy._types import (
LiteLLM_VerificationToken,
LiteLLM_UserTable,
LiteLLM_TeamTable,
)
from litellm.types.services import ServiceTypes
# Note: In our "fake" items we use dicts with fields that our fake reset functions modify.
# In a real-world scenario, these would be instances of LiteLLM_VerificationToken, LiteLLM_UserTable, etc.
@ -180,6 +170,106 @@ async def test_reset_budget_users_partial_failure():
)
@pytest.mark.asyncio
async def test_reset_budget_endusers_partial_failure():
    """
    Test that if one enduser fails to reset, the reset loop still processes the other endusers.
    We simulate six endusers where the first fails and the others are updated.
    """
    # user1 triggers the simulated failure; user2..user6 should be updated.
    initial_spends = (20.0, 25.0, 30.0, 35.0, 40.0, 45.0)
    endusers = [
        {"user_id": f"user{position}", "spend": spend, "budget_id": "budget1"}
        for position, spend in enumerate(initial_spends, start=1)
    ]

    # Created three days ago with a two-day duration, and no reset date set.
    expired_budget = LiteLLM_BudgetTableFull(
        budget_id="budget1",
        max_budget=65.0,
        budget_duration="2d",
        created_at=datetime.now(timezone.utc) - timedelta(days=3),
    )

    async def get_data_mock(table_name, *args, **kwargs):
        rows_by_table = {"budget": [expired_budget], "enduser": list(endusers)}
        return rows_by_table.get(table_name, [])

    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(side_effect=get_data_mock)
    prisma_client.update_data = AsyncMock()

    service_logger = MagicMock()
    service_logger.async_service_success_hook = AsyncMock()
    service_logger.async_service_failure_hook = AsyncMock()
    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = service_logger

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def fake_reset_enduser(enduser):
        if enduser["user_id"] == "user1":
            raise Exception("Simulated failure for user1")
        enduser["spend"] = 0.0
        return enduser

    with patch.object(
        ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser
    ) as mock_reset_enduser:
        await job.reset_budget_for_litellm_endusers()
        # Give the fire-and-forget logging task a chance to run.
        await asyncio.sleep(0.1)
        assert mock_reset_enduser.call_count == 6

    # One write for the budgets, one for the endusers; the enduser write is last.
    assert prisma_client.update_data.await_count == 2
    last_update = prisma_client.update_data.call_args
    assert last_update.kwargs.get("table_name") == "enduser"

    updated_users = last_update.kwargs.get("data_list", [])
    assert len(updated_users) == 5
    assert [user["user_id"] for user in updated_users] == [
        "user2",
        "user3",
        "user4",
        "user5",
        "user6",
    ]

    # The partial failure must be reported through the failure hook.
    assert any(
        call.kwargs.get("call_type") == "reset_budget_endusers"
        for call in service_logger.async_service_failure_hook.call_args_list
    )
@pytest.mark.asyncio
async def test_reset_budget_teams_partial_failure():
"""
@ -263,6 +353,15 @@ async def test_reset_budget_continues_other_categories_on_failure():
user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Succeeds
team1 = {"id": "team1", "spend": 30.0, "budget_duration": 180}
team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180}
enduser1 = {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}
budget1 = LiteLLM_BudgetTableFull(
**{
"budget_id": "budget1",
"max_budget": 65.0,
"budget_duration": "2d",
"created_at": datetime.now(timezone.utc) - timedelta(days=3),
}
)
prisma_client = MagicMock()
@ -273,6 +372,10 @@ async def test_reset_budget_continues_other_categories_on_failure():
return [user1, user2]
elif table_name == "team":
return [team1, team2]
elif table_name == "budget":
return [budget1]
elif table_name == "enduser":
return [enduser1]
return []
prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
@ -308,13 +411,19 @@ async def test_reset_budget_continues_other_categories_on_failure():
).isoformat()
return team
async def fake_reset_enduser(enduser):
enduser["spend"] = 0.0
return enduser
with patch.object(
ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key
) as mock_reset_key, patch.object(
ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user
) as mock_reset_user, patch.object(
ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team
) as mock_reset_team:
) as mock_reset_team, patch.object(
ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser
) as mock_reset_enduser:
# Call the overall reset_budget method.
await job.reset_budget()
await asyncio.sleep(0.1)
@ -323,10 +432,10 @@ async def test_reset_budget_continues_other_categories_on_failure():
called_tables = {
call.kwargs.get("table_name") for call in prisma_client.get_data.await_args_list
}
assert called_tables == {"key", "user", "team"}
assert called_tables == {"key", "user", "team", "budget", "enduser"}
# Verify that update_data was called three times (one per category)
assert prisma_client.update_data.await_count == 3
# Verify that update_data was called five times (once each for keys, users and teams, plus the budget and enduser writes of the enduser reset)
assert prisma_client.update_data.await_count == 5
calls = prisma_client.update_data.await_args_list
# Check keys update: both keys succeed.
@ -346,9 +455,14 @@ async def test_reset_budget_continues_other_categories_on_failure():
assert teams_call.kwargs.get("table_name") == "team"
assert len(teams_call.kwargs.get("data_list", [])) == 2
# Check enduser update: the single enduser succeeds.
enduser_call = calls[4]
assert enduser_call.kwargs.get("table_name") == "enduser"
assert len(enduser_call.kwargs.get("data_list", [])) == 1
# ---------------------------------------------------------------------------
# Additional tests for service logger behavior (keys, users, teams)
# Additional tests for service logger behavior (keys, users, teams, endusers)
# ---------------------------------------------------------------------------
@ -395,9 +509,10 @@ async def test_service_logger_keys_success():
# Verify success hook call
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
)
(
args,
kwargs,
) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_keys_found") == len(keys)
assert event_metadata.get("num_keys_updated") == len(keys)
@ -456,9 +571,10 @@ async def test_service_logger_keys_failure():
)
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
)
(
args,
kwargs,
) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_keys_found") == len(keys)
keys_found_str = event_metadata.get("keys_found", "")
@ -508,9 +624,10 @@ async def test_service_logger_users_success():
mock_verbose_exc.assert_not_called()
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
)
(
args,
kwargs,
) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_users_found") == len(users)
assert event_metadata.get("num_users_updated") == len(users)
@ -567,9 +684,10 @@ async def test_service_logger_users_failure():
)
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
)
(
args,
kwargs,
) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_users_found") == len(users)
users_found_str = event_metadata.get("users_found", "")
@ -618,9 +736,10 @@ async def test_service_logger_teams_success():
mock_verbose_exc.assert_not_called()
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
)
(
args,
kwargs,
) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_teams_found") == len(teams)
assert event_metadata.get("num_teams_updated") == len(teams)
@ -677,11 +796,156 @@ async def test_service_logger_teams_failure():
)
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
args, kwargs = (
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
)
(
args,
kwargs,
) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
event_metadata = kwargs.get("event_metadata", {})
assert event_metadata.get("num_teams_found") == len(teams)
teams_found_str = event_metadata.get("teams_found", "")
assert "team1" in teams_found_str
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
@pytest.mark.asyncio
async def test_service_logger_endusers_success():
    """
    When every enduser is reset successfully, the service logger's success hook
    fires once with matching counts, nothing is logged as an exception, and the
    failure hook stays untouched.
    """
    endusers = [
        {"user_id": user_id, "spend": 25.0, "budget_id": "budget1"}
        for user_id in ("user1", "user2")
    ]
    budgets = [
        LiteLLM_BudgetTableFull(
            budget_id="budget1",
            max_budget=65.0,
            budget_duration="2d",
            created_at=datetime.now(timezone.utc) - timedelta(days=3),
        )
    ]

    async def fake_get_data(*, table_name, query_type, **kwargs):
        if table_name == "budget":
            return budgets
        if table_name == "enduser":
            return endusers
        return []

    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
    prisma_client.update_data = AsyncMock()

    service_logger = MagicMock()
    service_logger.async_service_success_hook = AsyncMock()
    service_logger.async_service_failure_hook = AsyncMock()
    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = service_logger

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def fake_reset_enduser(enduser):
        enduser["spend"] = 0.0
        return enduser

    with patch.object(
        ResetBudgetJob,
        "_reset_budget_for_enduser",
        side_effect=fake_reset_enduser,
    ), patch(
        "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
    ) as mock_verbose_exc:
        await job.reset_budget_for_litellm_endusers()
        # Let the scheduled logging task complete.
        await asyncio.sleep(0.1)
        mock_verbose_exc.assert_not_called()

    service_logger.async_service_success_hook.assert_called_once()
    event_metadata = service_logger.async_service_success_hook.call_args.kwargs.get(
        "event_metadata", {}
    )
    assert event_metadata.get("num_budgets_found") == len(budgets)
    assert event_metadata.get("num_endusers_found") == len(endusers)
    assert event_metadata.get("num_endusers_updated") == len(endusers)
    assert event_metadata.get("num_endusers_failed") == 0
    service_logger.async_service_failure_hook.assert_not_called()
@pytest.mark.asyncio
async def test_service_logger_endusers_failure():
    """
    Test that a failure during enduser reset calls the failure hook with appropriate metadata,
    logs the exception, and does not call the success hook.

    NOTE: this test was previously named ``test_service_logger_users_failure``,
    which duplicated the name of the users-failure test defined earlier in this
    module. The later definition rebinds the module attribute, so the earlier
    test was never collected. The distinct name lets both tests run.
    """
    endusers = [
        {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"},
        {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"},
    ]
    # Created three days ago with a two-day duration, and no reset date set.
    budgets = [
        LiteLLM_BudgetTableFull(
            **{
                "budget_id": "budget1",
                "max_budget": 65.0,
                "budget_duration": "2d",
                "created_at": datetime.now(timezone.utc) - timedelta(days=3),
            }
        )
    ]

    async def fake_get_data(*, table_name, query_type, **kwargs):
        if table_name == "budget":
            return budgets
        elif table_name == "enduser":
            return endusers
        return []

    prisma_client = MagicMock()
    prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
    prisma_client.update_data = AsyncMock()

    proxy_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj = MagicMock()
    proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
    proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()

    job = ResetBudgetJob(proxy_logging_obj, prisma_client)

    async def fake_reset_enduser(enduser):
        # Only the first enduser fails; the second one is reset normally.
        if enduser["user_id"] == "user1":
            raise Exception("Simulated failure for user1")
        enduser["spend"] = 0.0
        return enduser

    with patch.object(
        ResetBudgetJob,
        "_reset_budget_for_enduser",
        side_effect=fake_reset_enduser,
    ):
        with patch(
            "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
        ) as mock_verbose_exc:
            await job.reset_budget_for_litellm_endusers()
            # Let the scheduled failure-hook task run.
            await asyncio.sleep(0.1)
            # Verify exception logging
            assert mock_verbose_exc.call_count >= 1
            # Verify exception was logged with correct message
            assert any(
                "Failed to reset budget for enduser" in str(call.args)
                for call in mock_verbose_exc.call_args_list
            )

    proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
    (
        args,
        kwargs,
    ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
    event_metadata = kwargs.get("event_metadata", {})
    assert event_metadata.get("num_budgets_found") == len(budgets)
    assert event_metadata.get("num_endusers_found") == len(endusers)
    endusers_found_str = event_metadata.get("endusers_found", "")
    assert "user1" in endusers_found_str
    proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()

View file

@ -0,0 +1,266 @@
import asyncio
import os
import sys
import uuid
from datetime import datetime, timedelta, timezone
import aiohttp
import pytest
import pytest_asyncio
from dotenv import load_dotenv
from litellm.caching.caching import DualCache
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.utils import PrismaClient, ProxyLogging
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
load_dotenv()
async def create_budget(session, data):
    """
    POST `data` to the local proxy's /budget/new endpoint.

    Asserts an HTTP 200 response, prints the created budget id and returns
    the decoded JSON body.
    """
    endpoint = "http://0.0.0.0:4000/budget/new"
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    async with session.post(endpoint, headers=request_headers, json=data) as response:
        assert response.status == 200
        created = await response.json()
        print(f"Created Budget {created['budget_id']}")
        return created
async def create_end_user(prisma_client, session, user_id, budget_id, spend=None):
    """
    Create an end user attached to `budget_id` through /end_user/new.

    When `spend` is given, the new row is additionally loaded through
    `prisma_client`, its spend is overwritten and the row is written back.
    Returns the decoded JSON body of the creation request.
    """
    endpoint = "http://0.0.0.0:4000/end_user/new"
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    payload = {"user_id": user_id, "budget_id": budget_id}

    async with session.post(
        endpoint, headers=request_headers, json=payload
    ) as response:
        assert response.status == 200
        response_data = await response.json()
        print(f"Created End User {response_data['user_id']}")

    if spend is None:
        return response_data

    budget_members = await prisma_client.get_data(
        table_name="enduser",
        query_type="find_all",
        budget_id_list=[budget_id],
    )
    # Indexing with [0] is deliberate: an IndexError surfaces if the user is missing.
    target = [member for member in budget_members if member.user_id == user_id][0]
    target.spend = spend
    await prisma_client.update_data(
        query_type="update_many",
        data_list=[target],
        table_name="enduser",
    )
    return response_data
async def delete_budget(session, budget_id):
    """Delete one budget via /budget/delete and assert the proxy answered 200."""
    endpoint = "http://0.0.0.0:4000/budget/delete"
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    payload = {"id": budget_id}
    async with session.post(
        endpoint, headers=request_headers, json=payload
    ) as response:
        assert response.status == 200
        print(f"Deleted Budget {budget_id}")
async def delete_end_user(session, user_id):
    """Delete one end user via /end_user/delete and assert the proxy answered 200."""
    endpoint = "http://0.0.0.0:4000/end_user/delete"
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    payload = {"user_ids": [user_id]}
    async with session.post(
        endpoint, headers=request_headers, json=payload
    ) as response:
        assert response.status == 200
        print(f"Deleted End User {user_id}")
@pytest.fixture
def prisma_client():
    """Build a PrismaClient for the database named by the DATABASE_URL env var."""
    logging_obj = ProxyLogging(user_api_key_cache=DualCache())
    database_url = os.environ["DATABASE_URL"]
    return PrismaClient(database_url=database_url, proxy_logging_obj=logging_obj)
class MockProxyLogging:
    """Stand-in for the proxy logging object: exposes a no-op `service_logging_obj`."""

    class MockServiceLogging:
        """Service logger whose success and failure hooks accept anything and do nothing."""

        async def async_service_success_hook(self, **kwargs):
            return None

        async def async_service_failure_hook(self, **kwargs):
            return None

    def __init__(self):
        service_logger = self.MockServiceLogging()
        self.service_logging_obj = service_logger
@pytest.fixture
def mock_proxy_logging():
    """Provide a fresh MockProxyLogging instance for each test."""
    logging_stub = MockProxyLogging()
    return logging_stub
@pytest.fixture
def reset_budget_job(prisma_client, mock_proxy_logging):
    """Build the ResetBudgetJob under test from the prisma client and the logging stub."""
    job = ResetBudgetJob(
        proxy_logging_obj=mock_proxy_logging,
        prisma_client=prisma_client,
    )
    return job
@pytest_asyncio.fixture
async def budget_and_enduser_setup(prisma_client):
    """
    Fixture to set up budgets and end users for testing and clean them up afterward.

    This fixture performs the following:
      - Creates two budgets:
          * Budget X with a short duration ("5s").
          * Budget Y with a long duration ("30d").
      - Stores the initial 'budget_reset_at' timestamps for both budgets, parsed
        from the create_budget responses.
      - Creates three end users:
          * End Users A and B are associated with Budget X and are given initial spend values.
          * End User C is associated with Budget Y with an initial spend.
      - Deletes the created end users and budgets afterward.

    Yields:
        dict: ``{"budgets": {...}, "end_users": {...}}`` holding the generated ids
        and the two initial ``budget_reset_at`` datetimes.

    Cleanup runs in a ``finally`` block and only touches records whose creation
    succeeded. pytest skips the code after a fixture's ``yield`` when the fixture
    raises before yielding, so without this a failure part-way through setup
    would leave the already-created budgets/end users behind.
    """
    # NOTE(review): connect() is called here but no matching disconnect is
    # visible in this fixture -- confirm whether one is needed.
    await prisma_client.connect()

    async with aiohttp.ClientSession() as session:
        # Track what was actually created so teardown never tries to delete
        # a record that does not exist.
        created_budget_ids = []
        created_end_user_ids = []
        try:
            # Create budgets
            id_budget_x = f"budget-{uuid.uuid4()}"
            data_budget_x = {
                "budget_id": id_budget_x,
                "budget_duration": "5s",
                "max_budget": 2,
            }
            id_budget_y = f"budget-{uuid.uuid4()}"
            data_budget_y = {
                "budget_id": id_budget_y,
                "budget_duration": "30d",
                "max_budget": 1,
            }

            response_budget_x = await create_budget(session, data_budget_x)
            created_budget_ids.append(id_budget_x)
            initial_budget_x_reset_at_date = datetime.fromisoformat(
                response_budget_x["budget_reset_at"]
            )

            response_budget_y = await create_budget(session, data_budget_y)
            created_budget_ids.append(id_budget_y)
            initial_budget_y_reset_at_date = datetime.fromisoformat(
                response_budget_y["budget_reset_at"]
            )

            # Create end users
            id_end_user_a = f"test-user-{uuid.uuid4()}"
            id_end_user_b = f"test-user-{uuid.uuid4()}"
            id_end_user_c = f"test-user-{uuid.uuid4()}"

            end_user_specs = (
                (id_end_user_a, id_budget_x, 0.16),
                (id_end_user_b, id_budget_x, 0.04),
                (id_end_user_c, id_budget_y, 0.2),
            )
            for end_user_id, budget_id, spend in end_user_specs:
                await create_end_user(
                    prisma_client, session, end_user_id, budget_id, spend=spend
                )
                created_end_user_ids.append(end_user_id)

            # Bundle data needed for the test
            setup_data = {
                "budgets": {
                    "id_budget_x": id_budget_x,
                    "id_budget_y": id_budget_y,
                    "initial_budget_x_reset_at_date": initial_budget_x_reset_at_date,
                    "initial_budget_y_reset_at_date": initial_budget_y_reset_at_date,
                },
                "end_users": {
                    "id_end_user_a": id_end_user_a,
                    "id_end_user_b": id_end_user_b,
                    "id_end_user_c": id_end_user_c,
                },
            }

            # Provide the setup data to the test
            yield setup_data
        finally:
            # Clean-up: end users first, then the budgets they were attached to
            # (same order as the original teardown).
            for end_user_id in created_end_user_ids:
                await delete_end_user(session, end_user_id)
            for budget_id in created_budget_ids:
                await delete_budget(session, budget_id)
@pytest.mark.asyncio
async def test_reset_budget_for_endusers(
    reset_budget_job, prisma_client, budget_and_enduser_setup
):
    """
    Exercise ResetBudgetJob.reset_budget_for_litellm_endusers against real records.

    Using the data created by the budget_and_enduser_setup fixture, the test
    sleeps past the "5s" duration of budget X, runs the end-user reset, and
    then checks that:
      - End users A and B (budget X) have spend equal to 0.
      - End user C (budget Y, "30d") still has spend greater than 0.
      - budget X's budget_reset_at is later than its initial value.
      - budget Y's budget_reset_at equals its initial value.
    """

    def first_match(records, attr, wanted):
        # Filter-then-index, so a missing record raises IndexError.
        return [rec for rec in records if getattr(rec, attr) == wanted][0]

    # Pull everything the test needs out of the fixture payload up front.
    budget_info = budget_and_enduser_setup["budgets"]
    end_user_info = budget_and_enduser_setup["end_users"]
    id_budget_x = budget_info["id_budget_x"]
    id_budget_y = budget_info["id_budget_y"]
    initial_budget_x_reset_at_date = budget_info["initial_budget_x_reset_at_date"]
    initial_budget_y_reset_at_date = budget_info["initial_budget_y_reset_at_date"]
    id_end_user_a = end_user_info["id_end_user_a"]
    id_end_user_b = end_user_info["id_end_user_b"]
    id_end_user_c = end_user_info["id_end_user_c"]

    # Budget X lasts 5 seconds; wait one extra second so it has expired.
    await asyncio.sleep(6)

    # Run only the end-user part of the reset job.
    await reset_budget_job.reset_budget_for_litellm_endusers()

    # Re-read the end users attached to either budget.
    refreshed_end_users = await prisma_client.get_data(
        table_name="enduser",
        query_type="find_all",
        budget_id_list=[id_budget_x, id_budget_y],
    )

    # Re-read budgets, querying with a reset_at 31 days in the future.
    reset_at_cutoff = datetime.now(timezone.utc) + timedelta(days=31)
    refreshed_budgets = await prisma_client.get_data(
        table_name="budget",
        query_type="find_all",
        reset_at=reset_at_cutoff,
    )

    # End-user assertions
    user_a, user_b, user_c = (
        first_match(refreshed_end_users, "user_id", wanted_id)
        for wanted_id in (id_end_user_a, id_end_user_b, id_end_user_c)
    )
    assert user_a.spend == 0, "Spend for end_user_a was not reset to 0"
    assert user_b.spend == 0, "Spend for end_user_b was not reset to 0"
    assert user_c.spend > 0, "Spend for end_user_c should not be reset"

    # Budget assertions
    budget_x = first_match(refreshed_budgets, "budget_id", id_budget_x)
    budget_y = first_match(refreshed_budgets, "budget_id", id_budget_y)
    assert (
        budget_x.budget_reset_at > initial_budget_x_reset_at_date
    ), "Budget X budget_reset_at was not updated"
    assert (
        budget_y.budget_reset_at == initial_budget_y_reset_at_date
    ), "Budget Y budget_reset_at should remain unchanged"

View file

@ -1792,4 +1792,4 @@ async def test_get_admin_team_ids(
where={"team_id": {"in": user_info.teams}}
)
else:
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()

View file

@ -0,0 +1,90 @@
# What is this?
## Unit tests for the /budget/* endpoints (they call a proxy expected to be running at http://0.0.0.0:4000)
import uuid
from datetime import datetime, timedelta
import aiohttp
import pytest
import pytest_asyncio
async def delete_budget(session, budget_id):
    """POST ``{"id": budget_id}`` to /budget/delete and require an HTTP 200 reply.

    Args:
        session: HTTP client session exposing an async-context ``post``.
        budget_id: Identifier of the budget to remove.

    Raises:
        AssertionError: If the response status is not 200.
    """
    endpoint = "http://0.0.0.0:4000/budget/delete"
    auth_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    async with session.post(
        endpoint, headers=auth_headers, json={"id": budget_id}
    ) as resp:
        assert resp.status == 200
        print(f"Deleted Budget {budget_id}")
async def create_budget(session, data):
    """Create a budget by POSTing ``data`` to the proxy's /budget/new route.

    Args:
        session: HTTP client session exposing an async-context ``post``.
        data: JSON-serialisable budget definition sent as the request body.

    Returns:
        The decoded JSON response body.

    Raises:
        AssertionError: If the response status is not 200.
        KeyError: If the response body has no "budget_id" entry.
    """
    async with session.post(
        "http://0.0.0.0:4000/budget/new",
        headers={
            "Authorization": "Bearer sk-1234",
            "Content-Type": "application/json",
        },
        json=data,
    ) as resp:
        assert resp.status == 200
        payload = await resp.json()
        print(f"Created Budget {payload['budget_id']}")
        return payload
@pytest_asyncio.fixture
async def budget_setup():
    """
    Create one budget for a test and delete it once the test is done.

    A random ``budget-<uuid4>`` id is generated and a budget with a "1d"
    duration and a max_budget of 0.02 is created through create_budget.

    Yields:
        dict: The JSON response returned by create_budget.

    After the test, the budget is removed with delete_budget using the same
    HTTP session.
    """
    new_budget_id = f"budget-{uuid.uuid4()}"
    payload = {
        "budget_id": new_budget_id,
        "budget_duration": "1d",
        "max_budget": 0.02,
    }

    async with aiohttp.ClientSession() as http:
        created = await create_budget(http, payload)

        # Hand the creation response to the test.
        yield created

        # Teardown: remove the budget created above.
        await delete_budget(http, new_budget_id)
@pytest.mark.asyncio
async def test_create_budget_with_duration(budget_setup):
    """
    Check that a budget created with a one-day duration gets a matching reset time.

    The budget_setup fixture supplies the creation response. The test requires
    'budget_reset_at' to be present and to lie within a few seconds of
    'created_at' plus one day.
    """
    reset_at_raw = budget_setup["budget_reset_at"]
    assert (
        reset_at_raw is not None
    ), "The budget_reset_at field should not be None"

    # Expected reset time: creation timestamp plus the one-day duration.
    created_at = datetime.fromisoformat(budget_setup["created_at"])
    expected_reset_at_date = created_at + timedelta(days=1)

    # Small allowance (in seconds) for how the timestamp was computed.
    tolerance_seconds = 3

    actual_reset_at_date = datetime.fromisoformat(reset_at_raw)
    drift = actual_reset_at_date - expected_reset_at_date
    time_difference = abs(drift.total_seconds())

    assert time_difference <= tolerance_seconds, (
        f"Expected budget_reset_at to be within {tolerance_seconds} seconds of {expected_reset_at_date}, "
        f"but the difference was {time_difference} seconds."
    )