mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge pull request #9329 from BerriAI/litellm_fix_reset_budget_job
[Bug fix] Reset Budget Job
This commit is contained in:
commit
5400615ce8
4 changed files with 384 additions and 24 deletions
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
|
@ -318,9 +318,11 @@ class ResetBudgetJob:
|
|||
async def _reset_budget_common(
|
||||
item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
|
||||
current_time: datetime,
|
||||
item_type: str,
|
||||
) -> Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken]:
|
||||
item_type: Literal["key", "team", "user"],
|
||||
):
|
||||
"""
|
||||
In-place, updates spend=0, and sets budget_reset_at to current_time + budget_duration
|
||||
|
||||
Common logic for resetting budget for a team, user, or key
|
||||
"""
|
||||
try:
|
||||
|
@ -339,19 +341,25 @@ class ResetBudgetJob:
|
|||
async def _reset_budget_for_team(
|
||||
team: LiteLLM_TeamTable, current_time: datetime
|
||||
) -> Optional[LiteLLM_TeamTable]:
|
||||
result = await ResetBudgetJob._reset_budget_common(team, current_time, "team")
|
||||
return result if isinstance(result, LiteLLM_TeamTable) else None
|
||||
await ResetBudgetJob._reset_budget_common(
|
||||
item=team, current_time=current_time, item_type="team"
|
||||
)
|
||||
return team
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_for_user(
|
||||
user: LiteLLM_UserTable, current_time: datetime
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
result = await ResetBudgetJob._reset_budget_common(user, current_time, "user")
|
||||
return result if isinstance(result, LiteLLM_UserTable) else None
|
||||
await ResetBudgetJob._reset_budget_common(
|
||||
item=user, current_time=current_time, item_type="user"
|
||||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def _reset_budget_for_key(
|
||||
key: LiteLLM_VerificationToken, current_time: datetime
|
||||
) -> Optional[LiteLLM_VerificationToken]:
|
||||
result = await ResetBudgetJob._reset_budget_common(key, current_time, "key")
|
||||
return result if isinstance(result, LiteLLM_VerificationToken) else None
|
||||
await ResetBudgetJob._reset_budget_common(
|
||||
item=key, current_time=current_time, item_type="key"
|
||||
)
|
||||
return key
|
||||
|
|
|
@ -32,7 +32,13 @@ from fastapi import HTTPException, status
|
|||
import litellm
|
||||
import litellm.litellm_core_utils
|
||||
import litellm.litellm_core_utils.litellm_logging
|
||||
from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream
|
||||
from litellm import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
Router,
|
||||
)
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging, ServiceTypes
|
||||
from litellm.caching.caching import DualCache, RedisCache
|
||||
|
@ -1009,19 +1015,24 @@ class ProxyLogging:
|
|||
for callback in litellm.callbacks:
|
||||
_callback: Optional[CustomLogger] = None
|
||||
if isinstance(callback, str):
|
||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback)
|
||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
||||
callback
|
||||
)
|
||||
else:
|
||||
_callback = callback # type: ignore
|
||||
if _callback is not None and isinstance(_callback, CustomLogger):
|
||||
if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail(
|
||||
data=request_data, event_type=GuardrailEventHooks.post_call
|
||||
if not isinstance(
|
||||
_callback, CustomGuardrail
|
||||
) or _callback.should_run_guardrail(
|
||||
data=request_data, event_type=GuardrailEventHooks.post_call
|
||||
):
|
||||
response = _callback.async_post_call_streaming_iterator_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response, request_data=request_data
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
response=response,
|
||||
request_data=request_data,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
async def post_call_streaming_hook(
|
||||
self,
|
||||
response: str,
|
||||
|
@ -1733,13 +1744,7 @@ class PrismaClient:
|
|||
verbose_proxy_logger.info("Data Inserted into User Table")
|
||||
return new_user_row
|
||||
elif table_name == "team":
|
||||
db_data = self.jsonify_object(data=data)
|
||||
if db_data.get("members_with_roles", None) is not None and isinstance(
|
||||
db_data["members_with_roles"], list
|
||||
):
|
||||
db_data["members_with_roles"] = json.dumps(
|
||||
db_data["members_with_roles"]
|
||||
)
|
||||
db_data = self.jsonify_team_object(db_data=data)
|
||||
new_team_row = await self.db.litellm_teamtable.upsert(
|
||||
where={"team_id": data["team_id"]},
|
||||
data={
|
||||
|
@ -2010,8 +2015,8 @@ class PrismaClient:
|
|||
batcher = self.db.batch_()
|
||||
for idx, team in enumerate(data_list):
|
||||
try:
|
||||
data_json = self.jsonify_object(
|
||||
data=team.model_dump(exclude_none=True)
|
||||
data_json = self.jsonify_team_object(
|
||||
db_data=team.model_dump(exclude_none=True)
|
||||
)
|
||||
except Exception:
|
||||
data_json = self.jsonify_object(
|
||||
|
|
201
tests/litellm/proxy/common_utils/test_reset_budget_job.py
Normal file
201
tests/litellm/proxy/common_utils/test_reset_budget_job.py
Normal file
|
@ -0,0 +1,201 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
|
||||
# Mock classes for testing
|
||||
class MockPrismaClient:
    """In-memory stand-in for the Prisma client handed to ``ResetBudgetJob``.

    ``data`` holds the rows that ``get_data`` returns and ``updated_data``
    records the last ``data_list`` passed to ``update_data``; both are keyed
    by table name ("key", "user", "team").
    """

    def __init__(self):
        self.data = {"key": [], "user": [], "team": []}
        self.updated_data = {"key": [], "user": [], "team": []}

    async def get_data(self, table_name, query_type, **kwargs):
        """Return the rows seeded for ``table_name`` (empty list if unknown).

        ``query_type`` and any extra keyword arguments are accepted but ignored.
        """
        return self.data.get(table_name, [])

    async def update_data(self, query_type, data_list, table_name):
        """Record ``data_list`` as the latest update for ``table_name`` and return it."""
        self.updated_data[table_name] = data_list
        return data_list
|
||||
|
||||
|
||||
class MockProxyLogging:
    """Minimal stand-in for ``ProxyLogging`` exposing only ``service_logging_obj``."""

    class MockServiceLogging:
        """No-op service logger: both hooks accept any keyword arguments."""

        async def async_service_success_hook(self, **kwargs):
            pass

        async def async_service_failure_hook(self, **kwargs):
            pass

    def __init__(self):
        self.service_logging_obj = self.MockServiceLogging()
|
||||
|
||||
|
||||
# Test fixtures
|
||||
@pytest.fixture
def mock_prisma_client():
    """Provide a fresh ``MockPrismaClient`` with empty tables for each test."""
    return MockPrismaClient()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_proxy_logging():
    """Provide a fresh ``MockProxyLogging`` whose service hooks do nothing."""
    return MockProxyLogging()
|
||||
|
||||
|
||||
@pytest.fixture
def reset_budget_job(mock_prisma_client, mock_proxy_logging):
    """Build a ``ResetBudgetJob`` wired to the mock Prisma client and mock logger."""
    return ResetBudgetJob(
        proxy_logging_obj=mock_proxy_logging, prisma_client=mock_prisma_client
    )
|
||||
|
||||
|
||||
# Helper function to run async tests
|
||||
async def run_async_test(coro):
    """Await ``coro`` and return its result."""
    return await coro
|
||||
|
||||
|
||||
# Tests
|
||||
def test_reset_budget_for_key(reset_budget_job, mock_prisma_client):
    """A key served by the mock client gets spend=0 and a later budget_reset_at."""
    # Setup test data
    now = datetime.utcnow()
    # type(...) builds a class, not an instance; its class attributes stand in
    # for the fields of a key row.
    test_key = type(
        "LiteLLM_VerificationToken",
        (),
        {
            "spend": 100.0,
            "budget_duration": "30d",
            "budget_reset_at": now,
            "id": "test-key-1",
        },
    )

    mock_prisma_client.data["key"] = [test_key]

    # Run the test
    asyncio.run(reset_budget_job.reset_budget_for_litellm_keys())

    # Verify results
    assert len(mock_prisma_client.updated_data["key"]) == 1
    updated_key = mock_prisma_client.updated_data["key"][0]
    assert updated_key.spend == 0.0
    assert updated_key.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_for_user(reset_budget_job, mock_prisma_client):
    """A user served by the mock client gets spend=0 and a later budget_reset_at."""
    # Setup test data
    now = datetime.utcnow()
    # type(...) builds a class, not an instance; its class attributes stand in
    # for the fields of a user row.
    test_user = type(
        "LiteLLM_UserTable",
        (),
        {
            "spend": 200.0,
            "budget_duration": "7d",
            "budget_reset_at": now,
            "id": "test-user-1",
        },
    )

    mock_prisma_client.data["user"] = [test_user]

    # Run the test
    asyncio.run(reset_budget_job.reset_budget_for_litellm_users())

    # Verify results
    assert len(mock_prisma_client.updated_data["user"]) == 1
    updated_user = mock_prisma_client.updated_data["user"][0]
    assert updated_user.spend == 0.0
    assert updated_user.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_for_team(reset_budget_job, mock_prisma_client):
    """A team served by the mock client gets spend=0 and a later budget_reset_at."""
    # Setup test data
    now = datetime.utcnow()
    # type(...) builds a class, not an instance; its class attributes stand in
    # for the fields of a team row.
    test_team = type(
        "LiteLLM_TeamTable",
        (),
        {
            "spend": 500.0,
            "budget_duration": "1mo",
            "budget_reset_at": now,
            "id": "test-team-1",
        },
    )

    mock_prisma_client.data["team"] = [test_team]

    # Run the test
    asyncio.run(reset_budget_job.reset_budget_for_litellm_teams())

    # Verify results
    assert len(mock_prisma_client.updated_data["team"]) == 1
    updated_team = mock_prisma_client.updated_data["team"][0]
    assert updated_team.spend == 0.0
    assert updated_team.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_all(reset_budget_job, mock_prisma_client):
    """``reset_budget()`` updates one key, one user and one team and zeroes each spend."""
    # Setup test data
    now = datetime.utcnow()

    # Create test objects for all three types. type(...) builds classes, not
    # instances; their class attributes stand in for row fields.
    test_key = type(
        "LiteLLM_VerificationToken",
        (),
        {
            "spend": 100.0,
            "budget_duration": "30d",
            "budget_reset_at": now,
            "id": "test-key-1",
        },
    )

    test_user = type(
        "LiteLLM_UserTable",
        (),
        {
            "spend": 200.0,
            "budget_duration": "7d",
            "budget_reset_at": now,
            "id": "test-user-1",
        },
    )

    test_team = type(
        "LiteLLM_TeamTable",
        (),
        {
            "spend": 500.0,
            "budget_duration": "1mo",
            "budget_reset_at": now,
            "id": "test-team-1",
        },
    )

    mock_prisma_client.data["key"] = [test_key]
    mock_prisma_client.data["user"] = [test_user]
    mock_prisma_client.data["team"] = [test_team]

    # Run the test
    asyncio.run(reset_budget_job.reset_budget())

    # Verify results
    assert len(mock_prisma_client.updated_data["key"]) == 1
    assert len(mock_prisma_client.updated_data["user"]) == 1
    assert len(mock_prisma_client.updated_data["team"]) == 1

    # Check that all spends were reset to 0
    assert mock_prisma_client.updated_data["key"][0].spend == 0.0
    assert mock_prisma_client.updated_data["user"][0].spend == 0.0
    assert mock_prisma_client.updated_data["team"][0].spend == 0.0
|
|
@ -3879,3 +3879,149 @@ async def test_get_paginated_teams(prisma_client):
|
|||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
pytest.fail(f"Test failed with exception: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("entity_type", ["key", "user", "team"])
async def test_reset_budget_job(prisma_client, entity_type):
    """
    Test that the ResetBudgetJob correctly resets spend for keys, users, and teams.

    For each entity type:
    1. Create a new entity with max_budget=100 and budget_duration=5s, then set spend=99
    2. Wait for the budget duration to elapse and call the reset_budget function
    3. Verify the entity's spend is reset to 0
    """
    import asyncio

    from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
    from litellm.proxy.utils import ProxyLogging

    # Setup
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
    reset_budget_job = ResetBudgetJob(
        proxy_logging_obj=proxy_logging_obj, prisma_client=prisma_client
    )

    # Create entity based on type
    entity_id = None
    if entity_type == "key":
        # Create a key with specific budget settings
        key = await generate_key_fn(
            data=GenerateKeyRequest(
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        entity_id = key.token_id
        print("generated key=", key)

        # Update the key to set a non-zero spend
        updated = await prisma_client.db.litellm_verificationtoken.update_many(
            where={"token": key.token_id},
            data={
                "spend": 99.0,
            },
        )
        print("Updated key=", updated)

    elif entity_type == "user":
        # Create a user with specific budget settings
        user = await new_user(
            data=NewUserRequest(
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        entity_id = user.user_id

        # Update the user to set a non-zero spend
        await prisma_client.db.litellm_usertable.update_many(
            where={"user_id": user.user_id},
            data={
                "spend": 99.0,
            },
        )

    elif entity_type == "team":
        # Create a team with specific budget settings
        team_id = f"test-team-{uuid.uuid4()}"
        await new_team(
            NewTeamRequest(
                team_id=team_id,
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
            http_request=Request(scope={"type": "http"}),
        )
        entity_id = team_id

        # Update the team to set a non-zero spend
        await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id},
            data={
                "spend": 99.0,
            },
        )

    async def _fetch_entity():
        """Look up the entity under test in the table matching ``entity_type``."""
        if entity_type == "key":
            return await prisma_client.db.litellm_verificationtoken.find_unique(
                where={"token": entity_id}
            )
        if entity_type == "user":
            return await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": entity_id}
            )
        if entity_type == "team":
            return await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": entity_id}
            )
        return None

    # Verify entity was created and updated with spend
    entity_before = await _fetch_entity()
    assert entity_before is not None
    assert entity_before.spend == 99.0

    # Wait for the 5 second budget duration to pass. Use asyncio.sleep so the
    # event loop is not blocked while waiting.
    print("sleeping for 5 seconds")
    await asyncio.sleep(5)

    # Call the reset_budget function
    await reset_budget_job.reset_budget()

    # Verify the entity's spend is reset
    entity_after = await _fetch_entity()
    assert entity_after is not None
    assert entity_after.spend == 0.0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue