From 91f4d4d865d75f0e2548e6fe1df63ea4676ff469 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 17 Mar 2025 18:56:24 -0700 Subject: [PATCH 1/5] fix _reset_budget_for_key/team/user --- litellm/proxy/common_utils/reset_budget_job.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/common_utils/reset_budget_job.py b/litellm/proxy/common_utils/reset_budget_job.py index 810b8c08c4..6a1842fc13 100644 --- a/litellm/proxy/common_utils/reset_budget_job.py +++ b/litellm/proxy/common_utils/reset_budget_job.py @@ -319,7 +319,7 @@ class ResetBudgetJob: item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken], current_time: datetime, item_type: str, - ) -> Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken]: + ): """ Common logic for resetting budget for a team, user, or key """ @@ -339,19 +339,19 @@ class ResetBudgetJob: async def _reset_budget_for_team( team: LiteLLM_TeamTable, current_time: datetime ) -> Optional[LiteLLM_TeamTable]: - result = await ResetBudgetJob._reset_budget_common(team, current_time, "team") - return result if isinstance(result, LiteLLM_TeamTable) else None + await ResetBudgetJob._reset_budget_common(team, current_time, "team") + return team @staticmethod async def _reset_budget_for_user( user: LiteLLM_UserTable, current_time: datetime ) -> Optional[LiteLLM_UserTable]: - result = await ResetBudgetJob._reset_budget_common(user, current_time, "user") - return result if isinstance(result, LiteLLM_UserTable) else None + await ResetBudgetJob._reset_budget_common(user, current_time, "user") + return user @staticmethod async def _reset_budget_for_key( key: LiteLLM_VerificationToken, current_time: datetime ) -> Optional[LiteLLM_VerificationToken]: - result = await ResetBudgetJob._reset_budget_common(key, current_time, "key") - return result if isinstance(result, LiteLLM_VerificationToken) else None + await ResetBudgetJob._reset_budget_common(key, current_time, "key") + return key From a9123d961d41f9cb341cd0b722e2166de160a728 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 17 Mar 2025 19:34:44 -0700 Subject: [PATCH 2/5] fix _reset_budget_for_team --- litellm/proxy/common_utils/reset_budget_job.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/litellm/proxy/common_utils/reset_budget_job.py b/litellm/proxy/common_utils/reset_budget_job.py index 6a1842fc13..1d50002f5c 100644 --- a/litellm/proxy/common_utils/reset_budget_job.py +++ b/litellm/proxy/common_utils/reset_budget_job.py @@ -2,7 +2,7 @@ import asyncio import json import time from datetime import datetime, timedelta -from typing import List, Optional, Union +from typing import List, Literal, Optional, Union from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.duration_parser import duration_in_seconds @@ -318,9 +318,11 @@ class ResetBudgetJob: async def _reset_budget_common( item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken], current_time: datetime, - item_type: str, + item_type: Literal["key", "team", "user"], ): """ + In-place, updates spend=0, and sets budget_reset_at to current_time + budget_duration + Common logic for resetting budget for a team, user, or key """ try: @@ -339,19 +341,25 @@ class ResetBudgetJob: async def _reset_budget_for_team( team: LiteLLM_TeamTable, current_time: datetime ) -> Optional[LiteLLM_TeamTable]: - await ResetBudgetJob._reset_budget_common(team, current_time, "team") + await ResetBudgetJob._reset_budget_common( + item=team, current_time=current_time, item_type="team" + ) return team @staticmethod async def _reset_budget_for_user( user: LiteLLM_UserTable, current_time: datetime ) -> Optional[LiteLLM_UserTable]: - await ResetBudgetJob._reset_budget_common(user, current_time, "user") + await ResetBudgetJob._reset_budget_common( + item=user, current_time=current_time, item_type="user" + ) return user @staticmethod async def _reset_budget_for_key( key: LiteLLM_VerificationToken, current_time: datetime ) -> Optional[LiteLLM_VerificationToken]: - await ResetBudgetJob._reset_budget_common(key, current_time, "key") + await ResetBudgetJob._reset_budget_common( + item=key, current_time=current_time, item_type="key" + ) return key From 63a32ff02c640cecc5ef4722f8d27a7cba903c3a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 17 Mar 2025 19:37:09 -0700 Subject: [PATCH 3/5] test_reset_budget_job --- .../test_key_generate_prisma.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index c47a37ec6a..f922b2c27f 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -3879,3 +3879,149 @@ async def test_get_paginated_teams(prisma_client): except Exception as e: print(f"Error occurred: {e}") pytest.fail(f"Test failed with exception: {e}") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("entity_type", ["key", "user", "team"]) +async def test_reset_budget_job(prisma_client, entity_type): + """ + Test that the ResetBudgetJob correctly resets budgets for keys, users, and teams. + + For each entity type: + 1. Create a new entity with max_budget=100, spend=99, budget_duration=5s + 2. Call the reset_budget function + 3. Verify the entity's spend is reset to 0 and budget_reset_at is updated + """ + from datetime import datetime, timedelta + import time + + from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob + from litellm.proxy.utils import ProxyLogging + + # Setup + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + proxy_logging_obj = ProxyLogging(user_api_key_cache=None) + reset_budget_job = ResetBudgetJob( + proxy_logging_obj=proxy_logging_obj, prisma_client=prisma_client + ) + + # Create entity based on type + entity_id = None + if entity_type == "key": + # Create a key with specific budget settings + key = await generate_key_fn( + data=GenerateKeyRequest( + max_budget=100, + budget_duration="5s", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + entity_id = key.token_id + print("generated key=", key) + + # Update the key to set spend and reset_at to now + updated = await prisma_client.db.litellm_verificationtoken.update_many( + where={"token": key.token_id}, + data={ + "spend": 99.0, + }, + ) + print("Updated key=", updated) + + elif entity_type == "user": + # Create a user with specific budget settings + user = await new_user( + data=NewUserRequest( + max_budget=100, + budget_duration="5s", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + entity_id = user.user_id + + # Update the user to set spend and reset_at to now + await prisma_client.db.litellm_usertable.update_many( + where={"user_id": user.user_id}, + data={ + "spend": 99.0, + }, + ) + + elif entity_type == "team": + # Create a team with specific budget settings + team_id = f"test-team-{uuid.uuid4()}" + team = await new_team( + NewTeamRequest( + team_id=team_id, + max_budget=100, + budget_duration="5s", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + http_request=Request(scope={"type": "http"}), + ) + entity_id = team_id + + # Update the team to set spend and reset_at to now + current_time = datetime.utcnow() + await prisma_client.db.litellm_teamtable.update( + where={"team_id": team_id}, + data={ + "spend": 99.0, + }, + ) + + # Verify entity was created and updated with spend + if entity_type == "key": + entity_before = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": entity_id} + ) + elif entity_type == "user": + entity_before = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": entity_id} + ) + elif entity_type == "team": + entity_before = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": entity_id} + ) + + assert entity_before is not None + assert entity_before.spend == 99.0 + + # Wait for 5 seconds to pass + print("sleeping for 5 seconds") + time.sleep(5) + + # Call the reset_budget function + await reset_budget_job.reset_budget() + + # Verify the entity's spend is reset and budget_reset_at is updated + if entity_type == "key": + entity_after = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": entity_id} + ) + elif entity_type == "user": + entity_after = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": entity_id} + ) + elif entity_type == "team": + entity_after = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": entity_id} + ) + + assert entity_after is not None + assert entity_after.spend == 0.0 From 09403e209752895389ac0be617b1873551d1ac11 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 17 Mar 2025 19:42:04 -0700 Subject: [PATCH 4/5] use jsonify_team_object for updating teams --- litellm/proxy/utils.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 399e87b145..57399a62c8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -32,7 +32,13 @@ from fastapi import HTTPException, status import litellm import litellm.litellm_core_utils import litellm.litellm_core_utils.litellm_logging -from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream +from litellm import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + ModelResponseStream, + Router, +) from litellm._logging import verbose_proxy_logger from litellm._service_logger import ServiceLogging, ServiceTypes from litellm.caching.caching import DualCache, RedisCache @@ -1009,19 +1015,24 @@ class ProxyLogging: for callback in litellm.callbacks: _callback: Optional[CustomLogger] = None if isinstance(callback, str): - _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback) + _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( + callback + ) else: _callback = callback # type: ignore if _callback is not None and isinstance(_callback, CustomLogger): - if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail( - data=request_data, event_type=GuardrailEventHooks.post_call + if not isinstance( + _callback, CustomGuardrail + ) or _callback.should_run_guardrail( + data=request_data, event_type=GuardrailEventHooks.post_call ): response = _callback.async_post_call_streaming_iterator_hook( - user_api_key_dict=user_api_key_dict, response=response, request_data=request_data + user_api_key_dict=user_api_key_dict, + response=response, + request_data=request_data, ) return response - async def post_call_streaming_hook( self, response: str, @@ -1733,13 +1744,7 @@ class PrismaClient: verbose_proxy_logger.info("Data Inserted into User Table") return new_user_row elif table_name == "team": - db_data = self.jsonify_object(data=data) - if db_data.get("members_with_roles", None) is not None and isinstance( - db_data["members_with_roles"], list - ): - db_data["members_with_roles"] = json.dumps( - db_data["members_with_roles"] - ) + db_data = self.jsonify_team_object(db_data=data) new_team_row = await self.db.litellm_teamtable.upsert( where={"team_id": data["team_id"]}, data={ @@ -2010,8 +2015,8 @@ class PrismaClient: batcher = self.db.batch_() for idx, team in enumerate(data_list): try: - data_json = self.jsonify_object( - data=team.model_dump(exclude_none=True) + data_json = self.jsonify_team_object( + db_data=team.model_dump(exclude_none=True) ) except Exception: data_json = self.jsonify_object( From fbace8d0415b5bbbb5062cb0fdc315dec0d51ec4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 17 Mar 2025 19:54:51 -0700 Subject: [PATCH 5/5] unit test reset budget job --- .../common_utils/test_reset_budget_job.py | 201 ++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 tests/litellm/proxy/common_utils/test_reset_budget_job.py diff --git a/tests/litellm/proxy/common_utils/test_reset_budget_job.py b/tests/litellm/proxy/common_utils/test_reset_budget_job.py new file mode 100644 index 0000000000..bb4af00d78 --- /dev/null +++ b/tests/litellm/proxy/common_utils/test_reset_budget_job.py @@ -0,0 +1,201 @@ +import asyncio +import json +import os +import sys +import time +from datetime import datetime, timedelta + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob +from litellm.proxy.utils import ProxyLogging + + +# Mock classes for testing +class MockPrismaClient: + def __init__(self): + self.data = {"key": [], "user": [], "team": []} + self.updated_data = {"key": [], "user": [], "team": []} + + async def get_data(self, table_name, query_type, **kwargs): + return self.data.get(table_name, []) + + async def update_data(self, query_type, data_list, table_name): + self.updated_data[table_name] = data_list + return data_list + + +class MockProxyLogging: + class MockServiceLogging: + async def async_service_success_hook(self, **kwargs): + pass + + async def async_service_failure_hook(self, **kwargs): + pass + + def __init__(self): + self.service_logging_obj = self.MockServiceLogging() + + +# Test fixtures +@pytest.fixture +def mock_prisma_client(): + return MockPrismaClient() + + +@pytest.fixture +def mock_proxy_logging(): + return MockProxyLogging() + + +@pytest.fixture +def reset_budget_job(mock_prisma_client, mock_proxy_logging): + return ResetBudgetJob( + proxy_logging_obj=mock_proxy_logging, prisma_client=mock_prisma_client + ) + + +# Helper function to run async tests +async def run_async_test(coro): + return await coro + + +# Tests +def test_reset_budget_for_key(reset_budget_job, mock_prisma_client): + # Setup test data + now = datetime.utcnow() + test_key = type( + "LiteLLM_VerificationToken", + (), + { + "spend": 100.0, + "budget_duration": "30d", + "budget_reset_at": now, + "id": "test-key-1", + }, + ) + + mock_prisma_client.data["key"] = [test_key] + + # Run the test + asyncio.run(reset_budget_job.reset_budget_for_litellm_keys()) + + # Verify results + assert len(mock_prisma_client.updated_data["key"]) == 1 + updated_key = mock_prisma_client.updated_data["key"][0] + assert updated_key.spend == 0.0 + assert updated_key.budget_reset_at > now + + +def test_reset_budget_for_user(reset_budget_job, mock_prisma_client): + # Setup test data + now = datetime.utcnow() + test_user = type( + "LiteLLM_UserTable", + (), + { + "spend": 200.0, + "budget_duration": "7d", + "budget_reset_at": now, + "id": "test-user-1", + }, + ) + + mock_prisma_client.data["user"] = [test_user] + + # Run the test + asyncio.run(reset_budget_job.reset_budget_for_litellm_users()) + + # Verify results + assert len(mock_prisma_client.updated_data["user"]) == 1 + updated_user = mock_prisma_client.updated_data["user"][0] + assert updated_user.spend == 0.0 + assert updated_user.budget_reset_at > now + + +def test_reset_budget_for_team(reset_budget_job, mock_prisma_client): + # Setup test data + now = datetime.utcnow() + test_team = type( + "LiteLLM_TeamTable", + (), + { + "spend": 500.0, + "budget_duration": "1mo", + "budget_reset_at": now, + "id": "test-team-1", + }, + ) + + mock_prisma_client.data["team"] = [test_team] + + # Run the test + asyncio.run(reset_budget_job.reset_budget_for_litellm_teams()) + + # Verify results + assert len(mock_prisma_client.updated_data["team"]) == 1 + updated_team = mock_prisma_client.updated_data["team"][0] + assert updated_team.spend == 0.0 + assert updated_team.budget_reset_at > now + + +def test_reset_budget_all(reset_budget_job, mock_prisma_client): + # Setup test data + now = datetime.utcnow() + + # Create test objects for all three types + test_key = type( + "LiteLLM_VerificationToken", + (), + { + "spend": 100.0, + "budget_duration": "30d", + "budget_reset_at": now, + "id": "test-key-1", + }, + ) + + test_user = type( + "LiteLLM_UserTable", + (), + { + "spend": 200.0, + "budget_duration": "7d", + "budget_reset_at": now, + "id": "test-user-1", + }, + ) + + test_team = type( + "LiteLLM_TeamTable", + (), + { + "spend": 500.0, + "budget_duration": "1mo", + "budget_reset_at": now, + "id": "test-team-1", + }, + ) + + mock_prisma_client.data["key"] = [test_key] + mock_prisma_client.data["user"] = [test_user] + mock_prisma_client.data["team"] = [test_team] + + # Run the test + asyncio.run(reset_budget_job.reset_budget()) + + # Verify results + assert len(mock_prisma_client.updated_data["key"]) == 1 + assert len(mock_prisma_client.updated_data["user"]) == 1 + assert len(mock_prisma_client.updated_data["team"]) == 1 + + # Check that all spends were reset to 0 + assert mock_prisma_client.updated_data["key"][0].spend == 0.0 + assert mock_prisma_client.updated_data["user"][0].spend == 0.0 + assert mock_prisma_client.updated_data["team"][0].spend == 0.0