Merge pull request #9329 from BerriAI/litellm_fix_reset_budget_job

[Bug fix] Reset Budget Job
This commit is contained in:
Ishaan Jaff 2025-03-17 21:46:08 -07:00 committed by GitHub
commit 5400615ce8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 384 additions and 24 deletions

View file

@ -2,7 +2,7 @@ import asyncio
import json import json
import time import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Optional, Union from typing import List, Literal, Optional, Union
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.litellm_core_utils.duration_parser import duration_in_seconds
@ -318,9 +318,11 @@ class ResetBudgetJob:
async def _reset_budget_common( async def _reset_budget_common(
item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken], item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
current_time: datetime, current_time: datetime,
item_type: str, item_type: Literal["key", "team", "user"],
) -> Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken]: ):
""" """
In-place, updates spend=0, and sets budget_reset_at to current_time + budget_duration
Common logic for resetting budget for a team, user, or key Common logic for resetting budget for a team, user, or key
""" """
try: try:
@ -339,19 +341,25 @@ class ResetBudgetJob:
async def _reset_budget_for_team( async def _reset_budget_for_team(
team: LiteLLM_TeamTable, current_time: datetime team: LiteLLM_TeamTable, current_time: datetime
) -> Optional[LiteLLM_TeamTable]: ) -> Optional[LiteLLM_TeamTable]:
result = await ResetBudgetJob._reset_budget_common(team, current_time, "team") await ResetBudgetJob._reset_budget_common(
return result if isinstance(result, LiteLLM_TeamTable) else None item=team, current_time=current_time, item_type="team"
)
return team
@staticmethod @staticmethod
async def _reset_budget_for_user( async def _reset_budget_for_user(
user: LiteLLM_UserTable, current_time: datetime user: LiteLLM_UserTable, current_time: datetime
) -> Optional[LiteLLM_UserTable]: ) -> Optional[LiteLLM_UserTable]:
result = await ResetBudgetJob._reset_budget_common(user, current_time, "user") await ResetBudgetJob._reset_budget_common(
return result if isinstance(result, LiteLLM_UserTable) else None item=user, current_time=current_time, item_type="user"
)
return user
@staticmethod @staticmethod
async def _reset_budget_for_key( async def _reset_budget_for_key(
key: LiteLLM_VerificationToken, current_time: datetime key: LiteLLM_VerificationToken, current_time: datetime
) -> Optional[LiteLLM_VerificationToken]: ) -> Optional[LiteLLM_VerificationToken]:
result = await ResetBudgetJob._reset_budget_common(key, current_time, "key") await ResetBudgetJob._reset_budget_common(
return result if isinstance(result, LiteLLM_VerificationToken) else None item=key, current_time=current_time, item_type="key"
)
return key

View file

@ -32,7 +32,13 @@ from fastapi import HTTPException, status
import litellm import litellm
import litellm.litellm_core_utils import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging import litellm.litellm_core_utils.litellm_logging
from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream from litellm import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
Router,
)
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching.caching import DualCache, RedisCache from litellm.caching.caching import DualCache, RedisCache
@ -1009,19 +1015,24 @@ class ProxyLogging:
for callback in litellm.callbacks: for callback in litellm.callbacks:
_callback: Optional[CustomLogger] = None _callback: Optional[CustomLogger] = None
if isinstance(callback, str): if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback) _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else: else:
_callback = callback # type: ignore _callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger): if _callback is not None and isinstance(_callback, CustomLogger):
if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail( if not isinstance(
_callback, CustomGuardrail
) or _callback.should_run_guardrail(
data=request_data, event_type=GuardrailEventHooks.post_call data=request_data, event_type=GuardrailEventHooks.post_call
): ):
response = _callback.async_post_call_streaming_iterator_hook( response = _callback.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict, response=response, request_data=request_data user_api_key_dict=user_api_key_dict,
response=response,
request_data=request_data,
) )
return response return response
async def post_call_streaming_hook( async def post_call_streaming_hook(
self, self,
response: str, response: str,
@ -1733,13 +1744,7 @@ class PrismaClient:
verbose_proxy_logger.info("Data Inserted into User Table") verbose_proxy_logger.info("Data Inserted into User Table")
return new_user_row return new_user_row
elif table_name == "team": elif table_name == "team":
db_data = self.jsonify_object(data=data) db_data = self.jsonify_team_object(db_data=data)
if db_data.get("members_with_roles", None) is not None and isinstance(
db_data["members_with_roles"], list
):
db_data["members_with_roles"] = json.dumps(
db_data["members_with_roles"]
)
new_team_row = await self.db.litellm_teamtable.upsert( new_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": data["team_id"]}, where={"team_id": data["team_id"]},
data={ data={
@ -2010,8 +2015,8 @@ class PrismaClient:
batcher = self.db.batch_() batcher = self.db.batch_()
for idx, team in enumerate(data_list): for idx, team in enumerate(data_list):
try: try:
data_json = self.jsonify_object( data_json = self.jsonify_team_object(
data=team.model_dump(exclude_none=True) db_data=team.model_dump(exclude_none=True)
) )
except Exception: except Exception:
data_json = self.jsonify_object( data_json = self.jsonify_object(

View file

@ -0,0 +1,201 @@
import asyncio
import json
import os
import sys
import time
from datetime import datetime, timedelta
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.utils import ProxyLogging
# Mock classes for testing
class MockPrismaClient:
    """In-memory stand-in for the Prisma client used by the ResetBudgetJob tests.

    ``data`` holds the rows that ``get_data`` returns for each table name, and
    ``updated_data`` records the most recent ``data_list`` handed to
    ``update_data`` for each table name.
    """

    def __init__(self):
        table_names = ("key", "user", "team")
        self.data = {name: [] for name in table_names}
        self.updated_data = {name: [] for name in table_names}

    async def get_data(self, table_name, query_type, **kwargs):
        """Return the rows seeded for ``table_name``; an empty list if none exist."""
        if table_name in self.data:
            return self.data[table_name]
        return []

    async def update_data(self, query_type, data_list, table_name):
        """Record ``data_list`` under ``table_name`` and hand it back unchanged."""
        self.updated_data[table_name] = data_list
        return data_list
class MockProxyLogging:
    """Minimal ProxyLogging replacement that exposes a no-op ``service_logging_obj``."""

    class MockServiceLogging:
        """Service logger whose success/failure hooks accept any kwargs and do nothing."""

        async def async_service_success_hook(self, **kwargs):
            return None

        async def async_service_failure_hook(self, **kwargs):
            return None

    def __init__(self):
        service_logger = self.MockServiceLogging()
        self.service_logging_obj = service_logger
# Test fixtures
@pytest.fixture
def mock_prisma_client():
    """Provide a fresh MockPrismaClient instance to the requesting test."""
    client = MockPrismaClient()
    return client
@pytest.fixture
def mock_proxy_logging():
    """Provide a fresh MockProxyLogging instance to the requesting test."""
    proxy_logging = MockProxyLogging()
    return proxy_logging
@pytest.fixture
def reset_budget_job(mock_prisma_client, mock_proxy_logging):
    """Build a ResetBudgetJob wired to the mock Prisma client and mock proxy logger."""
    job_kwargs = {
        "proxy_logging_obj": mock_proxy_logging,
        "prisma_client": mock_prisma_client,
    }
    return ResetBudgetJob(**job_kwargs)
# Helper function to run async tests
async def run_async_test(coro):
    """Await ``coro`` and return whatever it produces."""
    outcome = await coro
    return outcome
# Tests
def test_reset_budget_for_key(reset_budget_job, mock_prisma_client):
    """reset_budget_for_litellm_keys writes back the key with spend 0 and a later budget_reset_at."""
    # Arrange: seed a single key object into the mock "key" table
    start_time = datetime.utcnow()
    key_attrs = {
        "spend": 100.0,
        "budget_duration": "30d",
        "budget_reset_at": start_time,
        "id": "test-key-1",
    }
    test_key = type("LiteLLM_VerificationToken", (), key_attrs)
    mock_prisma_client.data["key"] = [test_key]

    # Act
    asyncio.run(reset_budget_job.reset_budget_for_litellm_keys())

    # Assert: exactly one key was passed to update_data, with its budget reset
    written_keys = mock_prisma_client.updated_data["key"]
    assert len(written_keys) == 1
    updated_key = written_keys[0]
    assert updated_key.spend == 0.0
    assert updated_key.budget_reset_at > start_time
def test_reset_budget_for_user(reset_budget_job, mock_prisma_client):
    """reset_budget_for_litellm_users writes back the user with spend 0 and a later budget_reset_at."""
    # Arrange: seed a single user object into the mock "user" table
    start_time = datetime.utcnow()
    user_attrs = {
        "spend": 200.0,
        "budget_duration": "7d",
        "budget_reset_at": start_time,
        "id": "test-user-1",
    }
    test_user = type("LiteLLM_UserTable", (), user_attrs)
    mock_prisma_client.data["user"] = [test_user]

    # Act
    asyncio.run(reset_budget_job.reset_budget_for_litellm_users())

    # Assert: exactly one user was passed to update_data, with its budget reset
    written_users = mock_prisma_client.updated_data["user"]
    assert len(written_users) == 1
    updated_user = written_users[0]
    assert updated_user.spend == 0.0
    assert updated_user.budget_reset_at > start_time
def test_reset_budget_for_team(reset_budget_job, mock_prisma_client):
    """reset_budget_for_litellm_teams writes back the team with spend 0 and a later budget_reset_at."""
    # Arrange: seed a single team object into the mock "team" table
    start_time = datetime.utcnow()
    team_attrs = {
        "spend": 500.0,
        "budget_duration": "1mo",
        "budget_reset_at": start_time,
        "id": "test-team-1",
    }
    test_team = type("LiteLLM_TeamTable", (), team_attrs)
    mock_prisma_client.data["team"] = [test_team]

    # Act
    asyncio.run(reset_budget_job.reset_budget_for_litellm_teams())

    # Assert: exactly one team was passed to update_data, with its budget reset
    written_teams = mock_prisma_client.updated_data["team"]
    assert len(written_teams) == 1
    updated_team = written_teams[0]
    assert updated_team.spend == 0.0
    assert updated_team.budget_reset_at > start_time
def test_reset_budget_all(reset_budget_job, mock_prisma_client):
    """reset_budget() resets keys, users and teams in a single run.

    Seeds one object per table, runs the combined job once, and checks that
    every table received exactly one update whose spend is 0 and whose
    budget_reset_at is later than the timestamp captured at the start of the
    test (the same checks the per-entity tests make).
    """
    # Setup test data
    now = datetime.utcnow()

    # (table name, mock class name, starting spend, budget duration, id)
    seeds = [
        ("key", "LiteLLM_VerificationToken", 100.0, "30d", "test-key-1"),
        ("user", "LiteLLM_UserTable", 200.0, "7d", "test-user-1"),
        ("team", "LiteLLM_TeamTable", 500.0, "1mo", "test-team-1"),
    ]
    for table_name, class_name, spend, budget_duration, entity_id in seeds:
        entity = type(
            class_name,
            (),
            {
                "spend": spend,
                "budget_duration": budget_duration,
                "budget_reset_at": now,
                "id": entity_id,
            },
        )
        mock_prisma_client.data[table_name] = [entity]

    # Run the test
    asyncio.run(reset_budget_job.reset_budget())

    # Verify results: one update per table, spend cleared, reset time advanced
    for table_name, *_ in seeds:
        updated = mock_prisma_client.updated_data[table_name]
        assert len(updated) == 1, f"expected one updated {table_name}"
        assert updated[0].spend == 0.0, f"{table_name} spend was not reset"
        assert (
            updated[0].budget_reset_at > now
        ), f"{table_name} budget_reset_at was not advanced"

View file

@ -3879,3 +3879,149 @@ async def test_get_paginated_teams(prisma_client):
except Exception as e: except Exception as e:
print(f"Error occurred: {e}") print(f"Error occurred: {e}")
pytest.fail(f"Test failed with exception: {e}") pytest.fail(f"Test failed with exception: {e}")
@pytest.mark.asyncio
@pytest.mark.parametrize("entity_type", ["key", "user", "team"])
async def test_reset_budget_job(prisma_client, entity_type):
    """
    Test that the ResetBudgetJob correctly resets budgets for keys, users, and teams.

    For each entity type:
    1. Create a new entity with max_budget=100 and budget_duration=5s, then set spend=99
    2. Wait for the budget duration to elapse and call the reset_budget function
    3. Verify the entity's spend is reset to 0 and budget_reset_at has moved forward
    """
    import asyncio

    from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
    from litellm.proxy.utils import ProxyLogging

    # Setup
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
    reset_budget_job = ResetBudgetJob(
        proxy_logging_obj=proxy_logging_obj, prisma_client=prisma_client
    )

    # Only one branch below runs per parametrized case, so a single admin
    # auth object can be shared by all of them.
    admin_auth = UserAPIKeyAuth(
        user_role=LitellmUserRoles.PROXY_ADMIN,
        api_key="sk-1234",
        user_id="1234",
    )

    async def _find_entity(lookup_id):
        """Fetch the current DB row for the entity type under test."""
        if entity_type == "key":
            return await prisma_client.db.litellm_verificationtoken.find_unique(
                where={"token": lookup_id}
            )
        if entity_type == "user":
            return await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": lookup_id}
            )
        return await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": lookup_id}
        )

    # Create entity based on type
    entity_id = None
    if entity_type == "key":
        # Create a key with specific budget settings
        key = await generate_key_fn(
            data=GenerateKeyRequest(
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=admin_auth,
        )
        entity_id = key.token_id
        print("generated key=", key)

        # Give the key some spend so the reset is observable
        updated = await prisma_client.db.litellm_verificationtoken.update_many(
            where={"token": key.token_id},
            data={
                "spend": 99.0,
            },
        )
        print("Updated key=", updated)
    elif entity_type == "user":
        # Create a user with specific budget settings
        user = await new_user(
            data=NewUserRequest(
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=admin_auth,
        )
        entity_id = user.user_id

        # Give the user some spend so the reset is observable
        await prisma_client.db.litellm_usertable.update_many(
            where={"user_id": user.user_id},
            data={
                "spend": 99.0,
            },
        )
    elif entity_type == "team":
        # Create a team with specific budget settings
        team_id = f"test-team-{uuid.uuid4()}"
        await new_team(
            NewTeamRequest(
                team_id=team_id,
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=admin_auth,
            http_request=Request(scope={"type": "http"}),
        )
        entity_id = team_id

        # Give the team some spend so the reset is observable
        await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id},
            data={
                "spend": 99.0,
            },
        )

    # Verify entity was created and updated with spend
    entity_before = await _find_entity(entity_id)
    assert entity_before is not None
    assert entity_before.spend == 99.0
    assert entity_before.budget_reset_at is not None

    # Wait for the 5 second budget duration to pass. Use asyncio.sleep rather
    # than time.sleep so the event loop is not blocked inside this coroutine.
    print("sleeping for 5 seconds")
    await asyncio.sleep(5)

    # Call the reset_budget function
    await reset_budget_job.reset_budget()

    # Verify the entity's spend is reset and budget_reset_at is updated
    entity_after = await _find_entity(entity_id)
    assert entity_after is not None
    assert entity_after.spend == 0.0
    assert entity_after.budget_reset_at is not None
    assert entity_after.budget_reset_at > entity_before.budget_reset_at