Merge pull request #9329 from BerriAI/litellm_fix_reset_budget_job

[Bug fix] Reset Budget Job
This commit is contained in:
Ishaan Jaff 2025-03-17 21:46:08 -07:00 committed by GitHub
commit 5400615ce8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 384 additions and 24 deletions

View file

@ -2,7 +2,7 @@ import asyncio
import json
import time
from datetime import datetime, timedelta
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
@ -318,9 +318,11 @@ class ResetBudgetJob:
async def _reset_budget_common(
item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
current_time: datetime,
item_type: str,
) -> Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken]:
item_type: Literal["key", "team", "user"],
):
"""
In-place, updates spend=0, and sets budget_reset_at to current_time + budget_duration
Common logic for resetting budget for a team, user, or key
"""
try:
@ -339,19 +341,25 @@ class ResetBudgetJob:
async def _reset_budget_for_team(
team: LiteLLM_TeamTable, current_time: datetime
) -> Optional[LiteLLM_TeamTable]:
result = await ResetBudgetJob._reset_budget_common(team, current_time, "team")
return result if isinstance(result, LiteLLM_TeamTable) else None
await ResetBudgetJob._reset_budget_common(
item=team, current_time=current_time, item_type="team"
)
return team
@staticmethod
async def _reset_budget_for_user(
user: LiteLLM_UserTable, current_time: datetime
) -> Optional[LiteLLM_UserTable]:
result = await ResetBudgetJob._reset_budget_common(user, current_time, "user")
return result if isinstance(result, LiteLLM_UserTable) else None
await ResetBudgetJob._reset_budget_common(
item=user, current_time=current_time, item_type="user"
)
return user
@staticmethod
async def _reset_budget_for_key(
key: LiteLLM_VerificationToken, current_time: datetime
) -> Optional[LiteLLM_VerificationToken]:
result = await ResetBudgetJob._reset_budget_common(key, current_time, "key")
return result if isinstance(result, LiteLLM_VerificationToken) else None
await ResetBudgetJob._reset_budget_common(
item=key, current_time=current_time, item_type="key"
)
return key

View file

@ -32,7 +32,13 @@ from fastapi import HTTPException, status
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream
from litellm import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
Router,
)
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching.caching import DualCache, RedisCache
@ -1009,19 +1015,24 @@ class ProxyLogging:
for callback in litellm.callbacks:
_callback: Optional[CustomLogger] = None
if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback)
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail(
data=request_data, event_type=GuardrailEventHooks.post_call
if not isinstance(
_callback, CustomGuardrail
) or _callback.should_run_guardrail(
data=request_data, event_type=GuardrailEventHooks.post_call
):
response = _callback.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict, response=response, request_data=request_data
user_api_key_dict=user_api_key_dict,
response=response,
request_data=request_data,
)
return response
async def post_call_streaming_hook(
self,
response: str,
@ -1733,13 +1744,7 @@ class PrismaClient:
verbose_proxy_logger.info("Data Inserted into User Table")
return new_user_row
elif table_name == "team":
db_data = self.jsonify_object(data=data)
if db_data.get("members_with_roles", None) is not None and isinstance(
db_data["members_with_roles"], list
):
db_data["members_with_roles"] = json.dumps(
db_data["members_with_roles"]
)
db_data = self.jsonify_team_object(db_data=data)
new_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": data["team_id"]},
data={
@ -2010,8 +2015,8 @@ class PrismaClient:
batcher = self.db.batch_()
for idx, team in enumerate(data_list):
try:
data_json = self.jsonify_object(
data=team.model_dump(exclude_none=True)
data_json = self.jsonify_team_object(
db_data=team.model_dump(exclude_none=True)
)
except Exception:
data_json = self.jsonify_object(

View file

@ -0,0 +1,201 @@
import asyncio
import json
import os
import sys
import time
from datetime import datetime, timedelta
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.utils import ProxyLogging
# Mock classes for testing
class MockPrismaClient:
def __init__(self):
self.data = {"key": [], "user": [], "team": []}
self.updated_data = {"key": [], "user": [], "team": []}
async def get_data(self, table_name, query_type, **kwargs):
return self.data.get(table_name, [])
async def update_data(self, query_type, data_list, table_name):
self.updated_data[table_name] = data_list
return data_list
class MockProxyLogging:
class MockServiceLogging:
async def async_service_success_hook(self, **kwargs):
pass
async def async_service_failure_hook(self, **kwargs):
pass
def __init__(self):
self.service_logging_obj = self.MockServiceLogging()
# Test fixtures
@pytest.fixture
def mock_prisma_client():
return MockPrismaClient()
@pytest.fixture
def mock_proxy_logging():
return MockProxyLogging()
@pytest.fixture
def reset_budget_job(mock_prisma_client, mock_proxy_logging):
return ResetBudgetJob(
proxy_logging_obj=mock_proxy_logging, prisma_client=mock_prisma_client
)
# Helper function to run async tests
async def run_async_test(coro):
return await coro
# Tests
def test_reset_budget_for_key(reset_budget_job, mock_prisma_client):
# Setup test data
now = datetime.utcnow()
test_key = type(
"LiteLLM_VerificationToken",
(),
{
"spend": 100.0,
"budget_duration": "30d",
"budget_reset_at": now,
"id": "test-key-1",
},
)
mock_prisma_client.data["key"] = [test_key]
# Run the test
asyncio.run(reset_budget_job.reset_budget_for_litellm_keys())
# Verify results
assert len(mock_prisma_client.updated_data["key"]) == 1
updated_key = mock_prisma_client.updated_data["key"][0]
assert updated_key.spend == 0.0
assert updated_key.budget_reset_at > now
def test_reset_budget_for_user(reset_budget_job, mock_prisma_client):
# Setup test data
now = datetime.utcnow()
test_user = type(
"LiteLLM_UserTable",
(),
{
"spend": 200.0,
"budget_duration": "7d",
"budget_reset_at": now,
"id": "test-user-1",
},
)
mock_prisma_client.data["user"] = [test_user]
# Run the test
asyncio.run(reset_budget_job.reset_budget_for_litellm_users())
# Verify results
assert len(mock_prisma_client.updated_data["user"]) == 1
updated_user = mock_prisma_client.updated_data["user"][0]
assert updated_user.spend == 0.0
assert updated_user.budget_reset_at > now
def test_reset_budget_for_team(reset_budget_job, mock_prisma_client):
# Setup test data
now = datetime.utcnow()
test_team = type(
"LiteLLM_TeamTable",
(),
{
"spend": 500.0,
"budget_duration": "1mo",
"budget_reset_at": now,
"id": "test-team-1",
},
)
mock_prisma_client.data["team"] = [test_team]
# Run the test
asyncio.run(reset_budget_job.reset_budget_for_litellm_teams())
# Verify results
assert len(mock_prisma_client.updated_data["team"]) == 1
updated_team = mock_prisma_client.updated_data["team"][0]
assert updated_team.spend == 0.0
assert updated_team.budget_reset_at > now
def test_reset_budget_all(reset_budget_job, mock_prisma_client):
# Setup test data
now = datetime.utcnow()
# Create test objects for all three types
test_key = type(
"LiteLLM_VerificationToken",
(),
{
"spend": 100.0,
"budget_duration": "30d",
"budget_reset_at": now,
"id": "test-key-1",
},
)
test_user = type(
"LiteLLM_UserTable",
(),
{
"spend": 200.0,
"budget_duration": "7d",
"budget_reset_at": now,
"id": "test-user-1",
},
)
test_team = type(
"LiteLLM_TeamTable",
(),
{
"spend": 500.0,
"budget_duration": "1mo",
"budget_reset_at": now,
"id": "test-team-1",
},
)
mock_prisma_client.data["key"] = [test_key]
mock_prisma_client.data["user"] = [test_user]
mock_prisma_client.data["team"] = [test_team]
# Run the test
asyncio.run(reset_budget_job.reset_budget())
# Verify results
assert len(mock_prisma_client.updated_data["key"]) == 1
assert len(mock_prisma_client.updated_data["user"]) == 1
assert len(mock_prisma_client.updated_data["team"]) == 1
# Check that all spends were reset to 0
assert mock_prisma_client.updated_data["key"][0].spend == 0.0
assert mock_prisma_client.updated_data["user"][0].spend == 0.0
assert mock_prisma_client.updated_data["team"][0].spend == 0.0

View file

@ -3879,3 +3879,149 @@ async def test_get_paginated_teams(prisma_client):
except Exception as e:
print(f"Error occurred: {e}")
pytest.fail(f"Test failed with exception: {e}")
@pytest.mark.asyncio
@pytest.mark.parametrize("entity_type", ["key", "user", "team"])
async def test_reset_budget_job(prisma_client, entity_type):
"""
Test that the ResetBudgetJob correctly resets budgets for keys, users, and teams.
For each entity type:
1. Create a new entity with max_budget=100, spend=99, budget_duration=5s
2. Call the reset_budget function
3. Verify the entity's spend is reset to 0 and budget_reset_at is updated
"""
from datetime import datetime, timedelta
import time
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.utils import ProxyLogging
# Setup
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
reset_budget_job = ResetBudgetJob(
proxy_logging_obj=proxy_logging_obj, prisma_client=prisma_client
)
# Create entity based on type
entity_id = None
if entity_type == "key":
# Create a key with specific budget settings
key = await generate_key_fn(
data=GenerateKeyRequest(
max_budget=100,
budget_duration="5s",
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
entity_id = key.token_id
print("generated key=", key)
# Update the key to set spend and reset_at to now
updated = await prisma_client.db.litellm_verificationtoken.update_many(
where={"token": key.token_id},
data={
"spend": 99.0,
},
)
print("Updated key=", updated)
elif entity_type == "user":
# Create a user with specific budget settings
user = await new_user(
data=NewUserRequest(
max_budget=100,
budget_duration="5s",
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
entity_id = user.user_id
# Update the user to set spend and reset_at to now
await prisma_client.db.litellm_usertable.update_many(
where={"user_id": user.user_id},
data={
"spend": 99.0,
},
)
elif entity_type == "team":
# Create a team with specific budget settings
team_id = f"test-team-{uuid.uuid4()}"
team = await new_team(
NewTeamRequest(
team_id=team_id,
max_budget=100,
budget_duration="5s",
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
http_request=Request(scope={"type": "http"}),
)
entity_id = team_id
# Update the team to set spend and reset_at to now
current_time = datetime.utcnow()
await prisma_client.db.litellm_teamtable.update(
where={"team_id": team_id},
data={
"spend": 99.0,
},
)
# Verify entity was created and updated with spend
if entity_type == "key":
entity_before = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": entity_id}
)
elif entity_type == "user":
entity_before = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": entity_id}
)
elif entity_type == "team":
entity_before = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": entity_id}
)
assert entity_before is not None
assert entity_before.spend == 99.0
# Wait for 5 seconds to pass
print("sleeping for 5 seconds")
time.sleep(5)
# Call the reset_budget function
await reset_budget_job.reset_budget()
# Verify the entity's spend is reset and budget_reset_at is updated
if entity_type == "key":
entity_after = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": entity_id}
)
elif entity_type == "user":
entity_after = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": entity_id}
)
elif entity_type == "team":
entity_after = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": entity_id}
)
assert entity_after is not None
assert entity_after.spend == 0.0