mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
Merge 587723996f
into b82af5b826
This commit is contained in:
commit
e6593570a6
9 changed files with 1022 additions and 47 deletions
|
@ -885,6 +885,10 @@ class BudgetNewRequest(LiteLLMPydanticObjectBase):
|
|||
default=None,
|
||||
description="Max budget for each model (e.g. {'gpt-4o': {'max_budget': '0.0000001', 'budget_duration': '1d', 'tpm_limit': 1000, 'rpm_limit': 1000}})",
|
||||
)
|
||||
budget_reset_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
description="Datetime when the budget is reset",
|
||||
)
|
||||
|
||||
|
||||
class BudgetRequest(LiteLLMPydanticObjectBase):
|
||||
|
@ -1206,10 +1210,18 @@ class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase):
|
|||
rpm_limit: Optional[int] = None
|
||||
model_max_budget: Optional[dict] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable):
|
||||
"""Represents all params for a LiteLLM_BudgetTable record"""
|
||||
|
||||
budget_id: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
|
||||
"""
|
||||
Used to track spend of a user_id within a team_id
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_BudgetTableFull,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_VerificationToken,
|
||||
|
@ -44,6 +46,141 @@ class ResetBudgetJob:
|
|||
## Reset Team Budget
|
||||
await self.reset_budget_for_litellm_teams()
|
||||
|
||||
### RESET ENDUSER (Customer) BUDGET and corresponding Budget duration ###
|
||||
await self.reset_budget_for_litellm_endusers()
|
||||
|
||||
async def reset_budget_for_litellm_endusers(self):
|
||||
"""
|
||||
Resets the budget for all LiteLLM End-Users (Customers) if their budget has expired
|
||||
The corresponding Budget duration is also updated.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
start_time = time.time()
|
||||
endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None
|
||||
budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None
|
||||
updated_endusers: List[LiteLLM_EndUserTable] = []
|
||||
failed_endusers = []
|
||||
try:
|
||||
budgets_to_reset = await self.prisma_client.get_data(
|
||||
table_name="budget", query_type="find_all", reset_at=now
|
||||
)
|
||||
|
||||
if budgets_to_reset is not None and len(budgets_to_reset) > 0:
|
||||
for budget in budgets_to_reset:
|
||||
budget = await ResetBudgetJob._reset_budget_reset_at_date(
|
||||
budget, now
|
||||
)
|
||||
await self.prisma_client.update_data(
|
||||
query_type="update_many",
|
||||
data_list=budgets_to_reset,
|
||||
table_name="budget",
|
||||
)
|
||||
|
||||
endusers_to_reset = await self.prisma_client.get_data(
|
||||
table_name="enduser",
|
||||
query_type="find_all",
|
||||
budget_id_list=[budget.budget_id for budget in budgets_to_reset],
|
||||
)
|
||||
|
||||
if endusers_to_reset is not None and len(endusers_to_reset) > 0:
|
||||
for enduser in endusers_to_reset:
|
||||
try:
|
||||
updated_enduser = (
|
||||
await ResetBudgetJob._reset_budget_for_enduser(
|
||||
enduser=enduser
|
||||
)
|
||||
)
|
||||
if updated_enduser is not None:
|
||||
updated_endusers.append(updated_enduser)
|
||||
else:
|
||||
failed_endusers.append(
|
||||
{
|
||||
"enduser": enduser,
|
||||
"error": "Returned None without exception",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
failed_endusers.append({"enduser": enduser, "error": str(e)})
|
||||
verbose_proxy_logger.exception(
|
||||
"Failed to reset budget for enduser: %s", enduser
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Updated users %s",
|
||||
json.dumps(updated_endusers, indent=4, default=str),
|
||||
)
|
||||
|
||||
await self.prisma_client.update_data(
|
||||
query_type="update_many",
|
||||
data_list=updated_endusers,
|
||||
table_name="enduser",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
if len(failed_endusers) > 0: # If any endusers failed to reset
|
||||
raise Exception(
|
||||
f"Failed to reset {len(failed_endusers)} endusers: {json.dumps(failed_endusers, default=str)}"
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
call_type="reset_budget_endusers",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_budgets_found": len(budgets_to_reset)
|
||||
if budgets_to_reset
|
||||
else 0,
|
||||
"budgets_found": json.dumps(
|
||||
budgets_to_reset, indent=4, default=str
|
||||
),
|
||||
"num_endusers_found": len(endusers_to_reset)
|
||||
if endusers_to_reset
|
||||
else 0,
|
||||
"endusers_found": json.dumps(
|
||||
endusers_to_reset, indent=4, default=str
|
||||
),
|
||||
"num_endusers_updated": len(updated_endusers),
|
||||
"endusers_updated": json.dumps(
|
||||
updated_endusers, indent=4, default=str
|
||||
),
|
||||
"num_endusers_failed": len(failed_endusers),
|
||||
"endusers_failed": json.dumps(
|
||||
failed_endusers, indent=4, default=str
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.RESET_BUDGET_JOB,
|
||||
duration=end_time - start_time,
|
||||
error=e,
|
||||
call_type="reset_budget_endusers",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"num_budgets_found": len(budgets_to_reset)
|
||||
if budgets_to_reset
|
||||
else 0,
|
||||
"budgets_found": json.dumps(
|
||||
budgets_to_reset, indent=4, default=str
|
||||
),
|
||||
"num_endusers_found": len(endusers_to_reset)
|
||||
if endusers_to_reset
|
||||
else 0,
|
||||
"endusers_found": json.dumps(
|
||||
endusers_to_reset, indent=4, default=str
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.exception("Failed to reset budget for endusers: %s", e)
|
||||
|
||||
async def reset_budget_for_litellm_keys(self):
|
||||
"""
|
||||
Resets the budget for all the litellm keys
|
||||
|
@ -355,6 +492,46 @@ class ResetBudgetJob:
|
|||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_for_enduser(
|
||||
enduser: LiteLLM_EndUserTable,
|
||||
) -> Optional[LiteLLM_EndUserTable]:
|
||||
try:
|
||||
enduser.spend = 0.0
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"Error resetting budget for enduser: %s. Item: %s", e, enduser
|
||||
)
|
||||
raise e
|
||||
return enduser
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_reset_at_date(
|
||||
budget: LiteLLM_BudgetTableFull, current_time: datetime
|
||||
) -> Optional[LiteLLM_BudgetTableFull]:
|
||||
try:
|
||||
if budget.budget_duration is not None:
|
||||
duration_s = duration_in_seconds(duration=budget.budget_duration)
|
||||
|
||||
# Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account
|
||||
if (
|
||||
budget.budget_reset_at is None
|
||||
and budget.created_at + timedelta(seconds=duration_s) > current_time
|
||||
):
|
||||
budget.budget_reset_at = budget.created_at + timedelta(
|
||||
seconds=duration_s
|
||||
)
|
||||
else:
|
||||
budget.budget_reset_at = current_time + timedelta(
|
||||
seconds=duration_s
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"Error resetting budget_reset_at for budget: %s. Item: %s", e, budget
|
||||
)
|
||||
raise e
|
||||
return budget
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_for_key(
|
||||
key: LiteLLM_VerificationToken, current_time: datetime
|
||||
|
|
|
@ -12,8 +12,10 @@ All /budget management endpoints
|
|||
"""
|
||||
|
||||
#### BUDGET TABLE MANAGEMENT ####
|
||||
from datetime import timedelta
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.utils import jsonify_object
|
||||
|
@ -51,6 +53,12 @@ async def new_budget(
|
|||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
# if no budget_reset_at date is set, but a budget_duration is given, then set budget_reset_at initially to the first completed duration interval in future
|
||||
if budget_obj.budget_reset_at is None and budget_obj.budget_duration is not None:
|
||||
budget_obj.budget_reset_at = datetime.utcnow() + timedelta(
|
||||
seconds=duration_in_seconds(duration=budget_obj.budget_duration)
|
||||
)
|
||||
|
||||
budget_obj_json = budget_obj.model_dump(exclude_none=True)
|
||||
budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries
|
||||
response = await prisma_client.db.litellm_budgettable.create(
|
||||
|
|
|
@ -1410,6 +1410,8 @@ class PrismaClient:
|
|||
"key",
|
||||
"config",
|
||||
"spend",
|
||||
"enduser",
|
||||
"budget",
|
||||
"team",
|
||||
"user_notification",
|
||||
"combined_view",
|
||||
|
@ -1424,6 +1426,7 @@ class PrismaClient:
|
|||
] = None, # pagination, number of rows to getch when find_all==True
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
budget_id_list: Optional[List[str]] = None,
|
||||
):
|
||||
args_passed_in = locals()
|
||||
start_time = time.time()
|
||||
|
@ -1603,6 +1606,29 @@ class PrismaClient:
|
|||
order={"startTime": "desc"},
|
||||
)
|
||||
return response
|
||||
elif table_name == "budget" and reset_at is not None:
|
||||
if query_type == "find_all":
|
||||
response = await self.db.litellm_budgettable.find_many(
|
||||
where={ # type:ignore
|
||||
"OR": [
|
||||
{
|
||||
"AND": [
|
||||
{"budget_reset_at": None},
|
||||
{"NOT": {"budget_duration": None}},
|
||||
]
|
||||
},
|
||||
{"budget_reset_at": {"lt": reset_at}},
|
||||
]
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
elif table_name == "enduser" and budget_id_list is not None:
|
||||
if query_type == "find_all":
|
||||
response = await self.db.litellm_endusertable.find_many(
|
||||
where={"budget_id": {"in": budget_id_list}}
|
||||
)
|
||||
return response
|
||||
elif table_name == "team":
|
||||
if query_type == "find_unique":
|
||||
response = await self.db.litellm_teamtable.find_unique(
|
||||
|
@ -1916,7 +1942,9 @@ class PrismaClient:
|
|||
user_id: Optional[str] = None,
|
||||
team_id: Optional[str] = None,
|
||||
query_type: Literal["update", "update_many"] = "update",
|
||||
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
|
||||
table_name: Optional[
|
||||
Literal["user", "key", "config", "spend", "team", "enduser", "budget"]
|
||||
] = None,
|
||||
update_key_values: Optional[dict] = None,
|
||||
update_key_values_custom_query: Optional[dict] = None,
|
||||
):
|
||||
|
@ -2083,6 +2111,68 @@ class PrismaClient:
|
|||
verbose_proxy_logger.info(
|
||||
"\033[91m" + "DB User Table Batch update succeeded" + "\033[0m"
|
||||
)
|
||||
elif (
|
||||
table_name is not None
|
||||
and table_name == "enduser"
|
||||
and query_type == "update_many"
|
||||
and data_list is not None
|
||||
and isinstance(data_list, list)
|
||||
):
|
||||
"""
|
||||
Batch write update queries
|
||||
"""
|
||||
batcher = self.db.batch_()
|
||||
for enduser in data_list:
|
||||
try:
|
||||
data_json = self.jsonify_object(
|
||||
data=enduser.model_dump(exclude_none=True)
|
||||
)
|
||||
except Exception:
|
||||
data_json = self.jsonify_object(data=enduser.dict())
|
||||
batcher.litellm_endusertable.upsert(
|
||||
where={"user_id": enduser.user_id}, # type: ignore
|
||||
data={
|
||||
"create": {**data_json}, # type: ignore
|
||||
"update": {
|
||||
**data_json # type: ignore
|
||||
}, # just update end-user-specified values, if it already exists
|
||||
},
|
||||
)
|
||||
await batcher.commit()
|
||||
verbose_proxy_logger.info(
|
||||
"\033[91m" + "DB End User Table Batch update succeeded" + "\033[0m"
|
||||
)
|
||||
elif (
|
||||
table_name is not None
|
||||
and table_name == "budget"
|
||||
and query_type == "update_many"
|
||||
and data_list is not None
|
||||
and isinstance(data_list, list)
|
||||
):
|
||||
"""
|
||||
Batch write update queries
|
||||
"""
|
||||
batcher = self.db.batch_()
|
||||
for budget in data_list:
|
||||
try:
|
||||
data_json = self.jsonify_object(
|
||||
data=budget.model_dump(exclude_none=True)
|
||||
)
|
||||
except Exception:
|
||||
data_json = self.jsonify_object(data=budget.dict())
|
||||
batcher.litellm_budgettable.upsert(
|
||||
where={"budget_id": budget.budget_id}, # type: ignore
|
||||
data={
|
||||
"create": {**data_json}, # type: ignore
|
||||
"update": {
|
||||
**data_json # type: ignore
|
||||
}, # just update end-user-specified values, if it already exists
|
||||
},
|
||||
)
|
||||
await batcher.commit()
|
||||
verbose_proxy_logger.info(
|
||||
"\033[91m" + "DB Budget Table Batch update succeeded" + "\033[0m"
|
||||
)
|
||||
elif (
|
||||
table_name is not None
|
||||
and table_name == "team"
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
|
@ -20,8 +17,14 @@ from litellm.proxy.utils import ProxyLogging
|
|||
# Mock classes for testing
|
||||
class MockPrismaClient:
|
||||
def __init__(self):
|
||||
self.data = {"key": [], "user": [], "team": []}
|
||||
self.updated_data = {"key": [], "user": [], "team": []}
|
||||
self.data = {"key": [], "user": [], "team": [], "budget": [], "enduser": []}
|
||||
self.updated_data = {
|
||||
"key": [],
|
||||
"user": [],
|
||||
"team": [],
|
||||
"budget": [],
|
||||
"enduser": [],
|
||||
}
|
||||
|
||||
async def get_data(self, table_name, query_type, **kwargs):
|
||||
return self.data.get(table_name, [])
|
||||
|
@ -145,9 +148,48 @@ def test_reset_budget_for_team(reset_budget_job, mock_prisma_client):
|
|||
assert updated_team.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client):
|
||||
# Setup test data
|
||||
now = datetime.now(timezone.utc)
|
||||
test_budget = type(
|
||||
"LiteLLM_BudgetTable",
|
||||
(),
|
||||
{
|
||||
"max_budget": 500.0,
|
||||
"budget_duration": "1d",
|
||||
"budget_reset_at": now,
|
||||
"budget_id": "test-budget-1",
|
||||
},
|
||||
)
|
||||
|
||||
test_enduser = type(
|
||||
"LiteLLM_EndUserTable",
|
||||
(),
|
||||
{
|
||||
"spend": 20.0,
|
||||
"litellm_budget_table": test_budget,
|
||||
"user_id": "test-enduser-1",
|
||||
},
|
||||
)
|
||||
|
||||
mock_prisma_client.data["budget"] = [test_budget]
|
||||
mock_prisma_client.data["enduser"] = [test_enduser]
|
||||
|
||||
# Run the test
|
||||
asyncio.run(reset_budget_job.reset_budget_for_litellm_endusers())
|
||||
|
||||
# Verify results
|
||||
assert len(mock_prisma_client.updated_data["enduser"]) == 1
|
||||
assert len(mock_prisma_client.updated_data["budget"]) == 1
|
||||
updated_enduser = mock_prisma_client.updated_data["enduser"][0]
|
||||
updated_budget = mock_prisma_client.updated_data["budget"][0]
|
||||
assert updated_enduser.spend == 0.0
|
||||
assert updated_budget.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_all(reset_budget_job, mock_prisma_client):
|
||||
# Setup test data
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Create test objects for all three types
|
||||
test_key = type(
|
||||
|
@ -183,9 +225,32 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client):
|
|||
},
|
||||
)
|
||||
|
||||
test_budget = type(
|
||||
"LiteLLM_BudgetTable",
|
||||
(),
|
||||
{
|
||||
"max_budget": 500.0,
|
||||
"budget_duration": "1d",
|
||||
"budget_reset_at": now,
|
||||
"budget_id": "test-budget-1",
|
||||
},
|
||||
)
|
||||
|
||||
test_enduser = type(
|
||||
"LiteLLM_EndUserTable",
|
||||
(),
|
||||
{
|
||||
"spend": 20.0,
|
||||
"litellm_budget_table": test_budget,
|
||||
"user_id": "test-enduser-1",
|
||||
},
|
||||
)
|
||||
|
||||
mock_prisma_client.data["key"] = [test_key]
|
||||
mock_prisma_client.data["user"] = [test_user]
|
||||
mock_prisma_client.data["team"] = [test_team]
|
||||
mock_prisma_client.data["budget"] = [test_budget]
|
||||
mock_prisma_client.data["enduser"] = [test_enduser]
|
||||
|
||||
# Run the test
|
||||
asyncio.run(reset_budget_job.reset_budget())
|
||||
|
@ -194,8 +259,11 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client):
|
|||
assert len(mock_prisma_client.updated_data["key"]) == 1
|
||||
assert len(mock_prisma_client.updated_data["user"]) == 1
|
||||
assert len(mock_prisma_client.updated_data["team"]) == 1
|
||||
assert len(mock_prisma_client.updated_data["enduser"]) == 1
|
||||
assert len(mock_prisma_client.updated_data["budget"]) == 1
|
||||
|
||||
# Check that all spends were reset to 0
|
||||
assert mock_prisma_client.updated_data["key"][0].spend == 0.0
|
||||
assert mock_prisma_client.updated_data["user"][0].spend == 0.0
|
||||
assert mock_prisma_client.updated_data["team"][0].spend == 0.0
|
||||
assert mock_prisma_client.updated_data["enduser"][0].spend == 0.0
|
||||
|
|
|
@ -1,32 +1,22 @@
|
|||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
import tempfile
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm.proxy._types import LiteLLM_BudgetTableFull
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_VerificationToken,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_TeamTable,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
# Note: In our "fake" items we use dicts with fields that our fake reset functions modify.
|
||||
# In a real-world scenario, these would be instances of LiteLLM_VerificationToken, LiteLLM_UserTable, etc.
|
||||
|
@ -180,6 +170,106 @@ async def test_reset_budget_users_partial_failure():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_budget_endusers_partial_failure():
|
||||
"""
|
||||
Test that if one enduser fails to reset, the reset loop still processes the other endusers.
|
||||
We simulate six endsers where the first fails and the others are updated.
|
||||
"""
|
||||
user1 = {
|
||||
"user_id": "user1",
|
||||
"spend": 20.0,
|
||||
"budget_id": "budget1",
|
||||
} # Will trigger simulated failure
|
||||
user2 = {
|
||||
"user_id": "user2",
|
||||
"spend": 25.0,
|
||||
"budget_id": "budget1",
|
||||
} # Should be updated
|
||||
user3 = {
|
||||
"user_id": "user3",
|
||||
"spend": 30.0,
|
||||
"budget_id": "budget1",
|
||||
} # Should be updated
|
||||
user4 = {
|
||||
"user_id": "user4",
|
||||
"spend": 35.0,
|
||||
"budget_id": "budget1",
|
||||
} # Should be updated
|
||||
user5 = {
|
||||
"user_id": "user5",
|
||||
"spend": 40.0,
|
||||
"budget_id": "budget1",
|
||||
} # Should be updated
|
||||
user6 = {
|
||||
"user_id": "user6",
|
||||
"spend": 45.0,
|
||||
"budget_id": "budget1",
|
||||
} # Should be updated
|
||||
|
||||
budget1 = LiteLLM_BudgetTableFull(
|
||||
**{
|
||||
"budget_id": "budget1",
|
||||
"max_budget": 65.0,
|
||||
"budget_duration": "2d",
|
||||
"created_at": datetime.now(timezone.utc) - timedelta(days=3),
|
||||
}
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
|
||||
async def get_data_mock(table_name, *args, **kwargs):
|
||||
if table_name == "budget":
|
||||
return [budget1]
|
||||
elif table_name == "enduser":
|
||||
return [user1, user2, user3, user4, user5, user6]
|
||||
return []
|
||||
|
||||
prisma_client.get_data = AsyncMock()
|
||||
prisma_client.get_data.side_effect = get_data_mock
|
||||
|
||||
prisma_client.update_data = AsyncMock()
|
||||
|
||||
proxy_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
|
||||
|
||||
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
|
||||
|
||||
async def fake_reset_enduser(enduser):
|
||||
if enduser["user_id"] == "user1":
|
||||
raise Exception("Simulated failure for user1")
|
||||
enduser["spend"] = 0.0
|
||||
return enduser
|
||||
|
||||
with patch.object(
|
||||
ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser
|
||||
) as mock_reset_enduser:
|
||||
await job.reset_budget_for_litellm_endusers()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert mock_reset_enduser.call_count == 6
|
||||
assert prisma_client.update_data.await_count == 2
|
||||
update_call = prisma_client.update_data.call_args
|
||||
assert update_call.kwargs.get("table_name") == "enduser"
|
||||
updated_users = update_call.kwargs.get("data_list", [])
|
||||
assert len(updated_users) == 5
|
||||
assert updated_users[0]["user_id"] == "user2"
|
||||
assert updated_users[1]["user_id"] == "user3"
|
||||
assert updated_users[2]["user_id"] == "user4"
|
||||
assert updated_users[3]["user_id"] == "user5"
|
||||
assert updated_users[4]["user_id"] == "user6"
|
||||
|
||||
failure_hook_calls = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list
|
||||
)
|
||||
assert any(
|
||||
call.kwargs.get("call_type") == "reset_budget_endusers"
|
||||
for call in failure_hook_calls
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_budget_teams_partial_failure():
|
||||
"""
|
||||
|
@ -263,6 +353,15 @@ async def test_reset_budget_continues_other_categories_on_failure():
|
|||
user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Succeeds
|
||||
team1 = {"id": "team1", "spend": 30.0, "budget_duration": 180}
|
||||
team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180}
|
||||
enduser1 = {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}
|
||||
budget1 = LiteLLM_BudgetTableFull(
|
||||
**{
|
||||
"budget_id": "budget1",
|
||||
"max_budget": 65.0,
|
||||
"budget_duration": "2d",
|
||||
"created_at": datetime.now(timezone.utc) - timedelta(days=3),
|
||||
}
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
|
||||
|
@ -273,6 +372,10 @@ async def test_reset_budget_continues_other_categories_on_failure():
|
|||
return [user1, user2]
|
||||
elif table_name == "team":
|
||||
return [team1, team2]
|
||||
elif table_name == "budget":
|
||||
return [budget1]
|
||||
elif table_name == "enduser":
|
||||
return [enduser1]
|
||||
return []
|
||||
|
||||
prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
|
||||
|
@ -308,13 +411,19 @@ async def test_reset_budget_continues_other_categories_on_failure():
|
|||
).isoformat()
|
||||
return team
|
||||
|
||||
async def fake_reset_enduser(enduser):
|
||||
enduser["spend"] = 0.0
|
||||
return enduser
|
||||
|
||||
with patch.object(
|
||||
ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key
|
||||
) as mock_reset_key, patch.object(
|
||||
ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user
|
||||
) as mock_reset_user, patch.object(
|
||||
ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team
|
||||
) as mock_reset_team:
|
||||
) as mock_reset_team, patch.object(
|
||||
ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser
|
||||
) as mock_reset_enduser:
|
||||
# Call the overall reset_budget method.
|
||||
await job.reset_budget()
|
||||
await asyncio.sleep(0.1)
|
||||
|
@ -323,10 +432,10 @@ async def test_reset_budget_continues_other_categories_on_failure():
|
|||
called_tables = {
|
||||
call.kwargs.get("table_name") for call in prisma_client.get_data.await_args_list
|
||||
}
|
||||
assert called_tables == {"key", "user", "team"}
|
||||
assert called_tables == {"key", "user", "team", "budget", "enduser"}
|
||||
|
||||
# Verify that update_data was called three times (one per category)
|
||||
assert prisma_client.update_data.await_count == 3
|
||||
# Verify that update_data was called three times (one per category, enduser update includes two)
|
||||
assert prisma_client.update_data.await_count == 5
|
||||
calls = prisma_client.update_data.await_args_list
|
||||
|
||||
# Check keys update: both keys succeed.
|
||||
|
@ -346,9 +455,14 @@ async def test_reset_budget_continues_other_categories_on_failure():
|
|||
assert teams_call.kwargs.get("table_name") == "team"
|
||||
assert len(teams_call.kwargs.get("data_list", [])) == 2
|
||||
|
||||
# Check enduser update: enduser succeed.
|
||||
enduser_call = calls[4]
|
||||
assert enduser_call.kwargs.get("table_name") == "enduser"
|
||||
assert len(enduser_call.kwargs.get("data_list", [])) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional tests for service logger behavior (keys, users, teams)
|
||||
# Additional tests for service logger behavior (keys, users, teams, endusers)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
@ -395,9 +509,10 @@ async def test_service_logger_keys_success():
|
|||
|
||||
# Verify success hook call
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
|
||||
args, kwargs = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
|
||||
)
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_keys_found") == len(keys)
|
||||
assert event_metadata.get("num_keys_updated") == len(keys)
|
||||
|
@ -456,9 +571,10 @@ async def test_service_logger_keys_failure():
|
|||
)
|
||||
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
|
||||
args, kwargs = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
|
||||
)
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_keys_found") == len(keys)
|
||||
keys_found_str = event_metadata.get("keys_found", "")
|
||||
|
@ -508,9 +624,10 @@ async def test_service_logger_users_success():
|
|||
mock_verbose_exc.assert_not_called()
|
||||
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
|
||||
args, kwargs = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
|
||||
)
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_users_found") == len(users)
|
||||
assert event_metadata.get("num_users_updated") == len(users)
|
||||
|
@ -567,9 +684,10 @@ async def test_service_logger_users_failure():
|
|||
)
|
||||
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
|
||||
args, kwargs = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
|
||||
)
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_users_found") == len(users)
|
||||
users_found_str = event_metadata.get("users_found", "")
|
||||
|
@ -618,9 +736,10 @@ async def test_service_logger_teams_success():
|
|||
mock_verbose_exc.assert_not_called()
|
||||
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
|
||||
args, kwargs = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
|
||||
)
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_teams_found") == len(teams)
|
||||
assert event_metadata.get("num_teams_updated") == len(teams)
|
||||
|
@ -677,11 +796,156 @@ async def test_service_logger_teams_failure():
|
|||
)
|
||||
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
|
||||
args, kwargs = (
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
|
||||
)
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_teams_found") == len(teams)
|
||||
teams_found_str = event_metadata.get("teams_found", "")
|
||||
assert "team1" in teams_found_str
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_logger_endusers_success():
|
||||
"""
|
||||
Test that when resetting endusers succeeds the service logger success hook is called with
|
||||
the correct metadata and no exception is logged.
|
||||
"""
|
||||
endusers = [
|
||||
{"user_id": "user1", "spend": 25.0, "budget_id": "budget1"},
|
||||
{"user_id": "user2", "spend": 25.0, "budget_id": "budget1"},
|
||||
]
|
||||
budgets = [
|
||||
LiteLLM_BudgetTableFull(
|
||||
**{
|
||||
"budget_id": "budget1",
|
||||
"max_budget": 65.0,
|
||||
"budget_duration": "2d",
|
||||
"created_at": datetime.now(timezone.utc) - timedelta(days=3),
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
async def fake_get_data(*, table_name, query_type, **kwargs):
|
||||
if table_name == "budget":
|
||||
return budgets
|
||||
elif table_name == "enduser":
|
||||
return endusers
|
||||
return []
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
|
||||
prisma_client.update_data = AsyncMock()
|
||||
|
||||
proxy_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
|
||||
|
||||
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
|
||||
|
||||
async def fake_reset_enduser(enduser):
|
||||
enduser["spend"] = 0.0
|
||||
return enduser
|
||||
|
||||
with patch.object(
|
||||
ResetBudgetJob,
|
||||
"_reset_budget_for_enduser",
|
||||
side_effect=fake_reset_enduser,
|
||||
):
|
||||
with patch(
|
||||
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
|
||||
) as mock_verbose_exc:
|
||||
await job.reset_budget_for_litellm_endusers()
|
||||
await asyncio.sleep(0.1)
|
||||
mock_verbose_exc.assert_not_called()
|
||||
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once()
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_budgets_found") == len(budgets)
|
||||
assert event_metadata.get("num_endusers_found") == len(endusers)
|
||||
assert event_metadata.get("num_endusers_updated") == len(endusers)
|
||||
assert event_metadata.get("num_endusers_failed") == 0
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_logger_users_failure():
|
||||
"""
|
||||
Test that a failure during enduser reset calls the failure hook with appropriate metadata,
|
||||
logs the exception, and does not call the success hook.
|
||||
"""
|
||||
endusers = [
|
||||
{"user_id": "user1", "spend": 25.0, "budget_id": "budget1"},
|
||||
{"user_id": "user2", "spend": 25.0, "budget_id": "budget1"},
|
||||
]
|
||||
budgets = [
|
||||
LiteLLM_BudgetTableFull(
|
||||
**{
|
||||
"budget_id": "budget1",
|
||||
"max_budget": 65.0,
|
||||
"budget_duration": "2d",
|
||||
"created_at": datetime.now(timezone.utc) - timedelta(days=3),
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
async def fake_get_data(*, table_name, query_type, **kwargs):
|
||||
if table_name == "budget":
|
||||
return budgets
|
||||
elif table_name == "enduser":
|
||||
return endusers
|
||||
return []
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.get_data = AsyncMock(side_effect=fake_get_data)
|
||||
prisma_client.update_data = AsyncMock()
|
||||
|
||||
proxy_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj = MagicMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock()
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock()
|
||||
|
||||
job = ResetBudgetJob(proxy_logging_obj, prisma_client)
|
||||
|
||||
async def fake_reset_enduser(enduser):
|
||||
if enduser["user_id"] == "user1":
|
||||
raise Exception("Simulated failure for user1")
|
||||
enduser["spend"] = 0.0
|
||||
return enduser
|
||||
|
||||
with patch.object(
|
||||
ResetBudgetJob,
|
||||
"_reset_budget_for_enduser",
|
||||
side_effect=fake_reset_enduser,
|
||||
):
|
||||
with patch(
|
||||
"litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception"
|
||||
) as mock_verbose_exc:
|
||||
await job.reset_budget_for_litellm_endusers()
|
||||
await asyncio.sleep(0.1)
|
||||
# Verify exception logging
|
||||
assert mock_verbose_exc.call_count >= 1
|
||||
# Verify exception was logged with correct message
|
||||
assert any(
|
||||
"Failed to reset budget for enduser" in str(call.args)
|
||||
for call in mock_verbose_exc.call_args_list
|
||||
)
|
||||
|
||||
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once()
|
||||
(
|
||||
args,
|
||||
kwargs,
|
||||
) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args
|
||||
event_metadata = kwargs.get("event_metadata", {})
|
||||
assert event_metadata.get("num_budgets_found") == len(budgets)
|
||||
assert event_metadata.get("num_endusers_found") == len(endusers)
|
||||
endusers_found_str = event_metadata.get("endusers_found", "")
|
||||
assert "user1" in endusers_found_str
|
||||
proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called()
|
||||
|
|
266
tests/local_testing/test_enduser_spend_reset.py
Normal file
266
tests/local_testing/test_enduser_spend_reset.py
Normal file
|
@ -0,0 +1,266 @@
|
|||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def create_budget(session, data):
|
||||
url = "http://0.0.0.0:4000/budget/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
assert response.status == 200
|
||||
response_data = await response.json()
|
||||
budget_id = response_data["budget_id"]
|
||||
print(f"Created Budget {budget_id}")
|
||||
return response_data
|
||||
|
||||
|
||||
async def create_end_user(prisma_client, session, user_id, budget_id, spend=None):
|
||||
url = "http://0.0.0.0:4000/end_user/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {
|
||||
"user_id": user_id,
|
||||
"budget_id": budget_id,
|
||||
}
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
assert response.status == 200
|
||||
response_data = await response.json()
|
||||
end_user_id = response_data["user_id"]
|
||||
print(f"Created End User {end_user_id}")
|
||||
|
||||
if spend is not None:
|
||||
end_users = await prisma_client.get_data(
|
||||
table_name="enduser",
|
||||
query_type="find_all",
|
||||
budget_id_list=[budget_id],
|
||||
)
|
||||
end_user = [user for user in end_users if user.user_id == user_id][0]
|
||||
end_user.spend = spend
|
||||
await prisma_client.update_data(
|
||||
query_type="update_many",
|
||||
data_list=[end_user],
|
||||
table_name="enduser",
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
|
||||
async def delete_budget(session, budget_id):
|
||||
url = "http://0.0.0.0:4000/budget/delete"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"id": budget_id}
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
assert response.status == 200
|
||||
print(f"Deleted Budget {budget_id}")
|
||||
|
||||
|
||||
async def delete_end_user(session, user_id):
|
||||
url = "http://0.0.0.0:4000/end_user/delete"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"user_ids": [user_id]}
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
assert response.status == 200
|
||||
print(f"Deleted End User {user_id}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prisma_client():
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||
prisma_client = PrismaClient(
|
||||
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
return prisma_client
|
||||
|
||||
|
||||
class MockProxyLogging:
|
||||
class MockServiceLogging:
|
||||
async def async_service_success_hook(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def async_service_failure_hook(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self):
|
||||
self.service_logging_obj = self.MockServiceLogging()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_proxy_logging():
|
||||
return MockProxyLogging()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_budget_job(prisma_client, mock_proxy_logging):
|
||||
return ResetBudgetJob(
|
||||
proxy_logging_obj=mock_proxy_logging, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def budget_and_enduser_setup(prisma_client):
|
||||
"""
|
||||
Fixture to set up budgets and end users for testing and clean them up afterward.
|
||||
|
||||
This fixture performs the following:
|
||||
- Creates two budgets:
|
||||
* Budget X with a short duration ("5s").
|
||||
* Budget Y with a long duration ("30d").
|
||||
- Stores the initial 'budget_reset_at' timestamps for both budgets.
|
||||
- Creates three end users:
|
||||
* End Users A and B are associated with Budget X and are given initial spend values.
|
||||
* End User C is associated with Budget Y with an initial spend.
|
||||
- After the test (after the yield), the created end users and budgets are deleted.
|
||||
"""
|
||||
await prisma_client.connect()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create budgets
|
||||
id_budget_x = f"budget-{uuid.uuid4()}"
|
||||
data_budget_x = {
|
||||
"budget_id": id_budget_x,
|
||||
"budget_duration": "5s",
|
||||
"max_budget": 2,
|
||||
}
|
||||
id_budget_y = f"budget-{uuid.uuid4()}"
|
||||
data_budget_y = {
|
||||
"budget_id": id_budget_y,
|
||||
"budget_duration": "30d",
|
||||
"max_budget": 1,
|
||||
}
|
||||
response_budget_x = await create_budget(session, data_budget_x)
|
||||
initial_budget_x_reset_at_date = datetime.fromisoformat(
|
||||
response_budget_x["budget_reset_at"]
|
||||
)
|
||||
response_budget_y = await create_budget(session, data_budget_y)
|
||||
initial_budget_y_reset_at_date = datetime.fromisoformat(
|
||||
response_budget_y["budget_reset_at"]
|
||||
)
|
||||
|
||||
# Create end users
|
||||
id_end_user_a = f"test-user-{uuid.uuid4()}"
|
||||
id_end_user_b = f"test-user-{uuid.uuid4()}"
|
||||
id_end_user_c = f"test-user-{uuid.uuid4()}"
|
||||
await create_end_user(
|
||||
prisma_client, session, id_end_user_a, id_budget_x, spend=0.16
|
||||
)
|
||||
await create_end_user(
|
||||
prisma_client, session, id_end_user_b, id_budget_x, spend=0.04
|
||||
)
|
||||
await create_end_user(
|
||||
prisma_client, session, id_end_user_c, id_budget_y, spend=0.2
|
||||
)
|
||||
|
||||
# Bundle data needed for the test
|
||||
setup_data = {
|
||||
"budgets": {
|
||||
"id_budget_x": id_budget_x,
|
||||
"id_budget_y": id_budget_y,
|
||||
"initial_budget_x_reset_at_date": initial_budget_x_reset_at_date,
|
||||
"initial_budget_y_reset_at_date": initial_budget_y_reset_at_date,
|
||||
},
|
||||
"end_users": {
|
||||
"id_end_user_a": id_end_user_a,
|
||||
"id_end_user_b": id_end_user_b,
|
||||
"id_end_user_c": id_end_user_c,
|
||||
},
|
||||
}
|
||||
|
||||
# Provide the setup data to the test
|
||||
yield setup_data
|
||||
|
||||
# Clean-up: Delete the created test data
|
||||
await delete_end_user(session, id_end_user_a)
|
||||
await delete_end_user(session, id_end_user_b)
|
||||
await delete_end_user(session, id_end_user_c)
|
||||
await delete_budget(session, id_budget_x)
|
||||
await delete_budget(session, id_budget_y)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_budget_for_endusers(
|
||||
reset_budget_job, prisma_client, budget_and_enduser_setup
|
||||
):
|
||||
"""
|
||||
Test the part "Reset End-User (Customer) Spend and corresponding Budget duration" in reset_budget function.
|
||||
|
||||
This test uses the budget_and_enduser_setup fixture to create budgets and end users,
|
||||
waits for the short-duration budget to expire, calls reset_budget, and verifies that:
|
||||
- End users associated with the short-duration budget X have their spend reset to 0.
|
||||
- The budget_reset_at timestamp for the short-duration budget X is updated,
|
||||
while the long-duration budget Y remains unchanged.
|
||||
"""
|
||||
|
||||
# Unpack the required data from the fixture
|
||||
budgets = budget_and_enduser_setup["budgets"]
|
||||
end_users = budget_and_enduser_setup["end_users"]
|
||||
|
||||
id_budget_x = budgets["id_budget_x"]
|
||||
id_budget_y = budgets["id_budget_y"]
|
||||
initial_budget_x_reset_at_date = budgets["initial_budget_x_reset_at_date"]
|
||||
initial_budget_y_reset_at_date = budgets["initial_budget_y_reset_at_date"]
|
||||
|
||||
id_end_user_a = end_users["id_end_user_a"]
|
||||
id_end_user_b = end_users["id_end_user_b"]
|
||||
id_end_user_c = end_users["id_end_user_c"]
|
||||
|
||||
# Wait for Budget X to expire (short duration "5s" plus a small buffer)
|
||||
await asyncio.sleep(6)
|
||||
|
||||
# Call the reset_budget function:
|
||||
# It should reset the spend values for end users associated with Budget X.
|
||||
await reset_budget_job.reset_budget_for_litellm_endusers()
|
||||
|
||||
# Retrieve updated data for end users
|
||||
updated_end_users = await prisma_client.get_data(
|
||||
table_name="enduser",
|
||||
query_type="find_all",
|
||||
budget_id_list=[id_budget_x, id_budget_y],
|
||||
)
|
||||
# Retrieve updated data for budgets
|
||||
updated_budgets = await prisma_client.get_data(
|
||||
table_name="budget",
|
||||
query_type="find_all",
|
||||
reset_at=datetime.now(timezone.utc) + timedelta(days=31),
|
||||
)
|
||||
|
||||
# Assertions for end users
|
||||
user_a = [user for user in updated_end_users if user.user_id == id_end_user_a][0]
|
||||
user_b = [user for user in updated_end_users if user.user_id == id_end_user_b][0]
|
||||
user_c = [user for user in updated_end_users if user.user_id == id_end_user_c][0]
|
||||
|
||||
assert user_a.spend == 0, "Spend for end_user_a was not reset to 0"
|
||||
assert user_b.spend == 0, "Spend for end_user_b was not reset to 0"
|
||||
assert user_c.spend > 0, "Spend for end_user_c should not be reset"
|
||||
|
||||
# Assertions for budgets
|
||||
budget_x = [
|
||||
budget for budget in updated_budgets if budget.budget_id == id_budget_x
|
||||
][0]
|
||||
budget_y = [
|
||||
budget for budget in updated_budgets if budget.budget_id == id_budget_y
|
||||
][0]
|
||||
|
||||
assert (
|
||||
budget_x.budget_reset_at > initial_budget_x_reset_at_date
|
||||
), "Budget X budget_reset_at was not updated"
|
||||
assert (
|
||||
budget_y.budget_reset_at == initial_budget_y_reset_at_date
|
||||
), "Budget Y budget_reset_at should remain unchanged"
|
|
@ -1792,4 +1792,4 @@ async def test_get_admin_team_ids(
|
|||
where={"team_id": {"in": user_info.teams}}
|
||||
)
|
||||
else:
|
||||
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()
|
||||
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()
|
90
tests/test_budget_management.py
Normal file
90
tests/test_budget_management.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
# What is this?
|
||||
## Unit tests for the /budget/* endpoints
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
async def delete_budget(session, budget_id):
|
||||
url = "http://0.0.0.0:4000/budget/delete"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"id": budget_id}
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
assert response.status == 200
|
||||
print(f"Deleted Budget {budget_id}")
|
||||
|
||||
|
||||
async def create_budget(session, data):
|
||||
url = "http://0.0.0.0:4000/budget/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
assert response.status == 200
|
||||
response_data = await response.json()
|
||||
budget_id = response_data["budget_id"]
|
||||
print(f"Created Budget {budget_id}")
|
||||
return response_data
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def budget_setup():
|
||||
"""
|
||||
Fixture to create a budget for testing and clean it up afterward.
|
||||
|
||||
This fixture performs the following steps:
|
||||
1. Opens an aiohttp ClientSession.
|
||||
2. Generates a random budget_id and defines the budget data (duration: 1 day, max_budget: 0.02).
|
||||
3. Calls create_budget to create the budget.
|
||||
4. Yields the budget_response (a dict) for use in the test.
|
||||
5. After the test completes, deletes the created budget by calling delete_budget.
|
||||
|
||||
Returns:
|
||||
dict: The JSON response from create_budget, which includes the created budget's data.
|
||||
"""
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Generate a unique budget_id and define the budget data.
|
||||
budget_id = f"budget-{uuid.uuid4()}"
|
||||
data = {"budget_id": budget_id, "budget_duration": "1d", "max_budget": 0.02}
|
||||
budget_response = await create_budget(session, data)
|
||||
|
||||
# Yield the response so the test can use it.
|
||||
yield budget_response
|
||||
|
||||
# After the test, delete the created budget to clean up.
|
||||
await delete_budget(session, budget_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_budget_with_duration(budget_setup):
|
||||
"""
|
||||
Test creating a budget with a specified duration and verify that the 'budget_reset_at'
|
||||
timestamp is correctly calculated as 'created_at' plus the budget duration (one day).
|
||||
|
||||
This test uses the budget_setup fixture, which handles both the creation and cleanup of the budget.
|
||||
"""
|
||||
|
||||
# Verify that the response includes a 'budget_reset_at' timestamp.
|
||||
assert (
|
||||
budget_setup["budget_reset_at"] is not None
|
||||
), "The budget_reset_at field should not be None"
|
||||
|
||||
# Calculate the expected reset time: created_at + 1 day.
|
||||
expected_reset_at_date = datetime.fromisoformat(
|
||||
budget_setup["created_at"]
|
||||
) + timedelta(days=1)
|
||||
|
||||
# Allow for a small tolerance in seconds for the timestamp calculation.
|
||||
tolerance_seconds = 3
|
||||
actual_reset_at_date = datetime.fromisoformat(budget_setup["budget_reset_at"])
|
||||
time_difference = abs(
|
||||
(actual_reset_at_date - expected_reset_at_date).total_seconds()
|
||||
)
|
||||
|
||||
assert time_difference <= tolerance_seconds, (
|
||||
f"Expected budget_reset_at to be within {tolerance_seconds} seconds of {expected_reset_at_date}, "
|
||||
f"but the difference was {time_difference} seconds."
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue