diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 354f6bb54c..226bbdf792 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -885,6 +885,10 @@ class BudgetNewRequest(LiteLLMPydanticObjectBase): default=None, description="Max budget for each model (e.g. {'gpt-4o': {'max_budget': '0.0000001', 'budget_duration': '1d', 'tpm_limit': 1000, 'rpm_limit': 1000}})", ) + budget_reset_at: Optional[datetime] = Field( + default=None, + description="Datetime when the budget is reset", + ) class BudgetRequest(LiteLLMPydanticObjectBase): @@ -1206,10 +1210,18 @@ class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase): rpm_limit: Optional[int] = None model_max_budget: Optional[dict] = None budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None model_config = ConfigDict(protected_namespaces=()) +class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable): + """Represents all params for a LiteLLM_BudgetTable record""" + + budget_id: str + created_at: datetime + + class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): """ Used to track spend of a user_id within a team_id diff --git a/litellm/proxy/common_utils/reset_budget_job.py b/litellm/proxy/common_utils/reset_budget_job.py index 1d50002f5c..0b82eecd9d 100644 --- a/litellm/proxy/common_utils/reset_budget_job.py +++ b/litellm/proxy/common_utils/reset_budget_job.py @@ -1,12 +1,14 @@ import asyncio import json import time -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import List, Literal, Optional, Union from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import ( + LiteLLM_BudgetTableFull, + LiteLLM_EndUserTable, LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken, @@ -44,6 +46,141 @@ class ResetBudgetJob: ## Reset Team Budget await self.reset_budget_for_litellm_teams() + ### RESET ENDUSER (Customer) BUDGET and 
corresponding Budget duration ### + await self.reset_budget_for_litellm_endusers() + + async def reset_budget_for_litellm_endusers(self): + """ + Resets the budget for all LiteLLM End-Users (Customers) if their budget has expired + The corresponding Budget duration is also updated. + """ + now = datetime.now(timezone.utc) + start_time = time.time() + endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None + budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None + updated_endusers: List[LiteLLM_EndUserTable] = [] + failed_endusers = [] + try: + budgets_to_reset = await self.prisma_client.get_data( + table_name="budget", query_type="find_all", reset_at=now + ) + + if budgets_to_reset is not None and len(budgets_to_reset) > 0: + for budget in budgets_to_reset: + budget = await ResetBudgetJob._reset_budget_reset_at_date( + budget, now + ) + await self.prisma_client.update_data( + query_type="update_many", + data_list=budgets_to_reset, + table_name="budget", + ) + + endusers_to_reset = await self.prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[budget.budget_id for budget in budgets_to_reset], + ) + + if endusers_to_reset is not None and len(endusers_to_reset) > 0: + for enduser in endusers_to_reset: + try: + updated_enduser = ( + await ResetBudgetJob._reset_budget_for_enduser( + enduser=enduser + ) + ) + if updated_enduser is not None: + updated_endusers.append(updated_enduser) + else: + failed_endusers.append( + { + "enduser": enduser, + "error": "Returned None without exception", + } + ) + except Exception as e: + failed_endusers.append({"enduser": enduser, "error": str(e)}) + verbose_proxy_logger.exception( + "Failed to reset budget for enduser: %s", enduser + ) + + verbose_proxy_logger.debug( + "Updated users %s", + json.dumps(updated_endusers, indent=4, default=str), + ) + + await self.prisma_client.update_data( + query_type="update_many", + data_list=updated_endusers, + table_name="enduser", + ) + + 
end_time = time.time() + if len(failed_endusers) > 0: # If any endusers failed to reset + raise Exception( + f"Failed to reset {len(failed_endusers)} endusers: {json.dumps(failed_endusers, default=str)}" + ) + + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + call_type="reset_budget_endusers", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_budgets_found": len(budgets_to_reset) + if budgets_to_reset + else 0, + "budgets_found": json.dumps( + budgets_to_reset, indent=4, default=str + ), + "num_endusers_found": len(endusers_to_reset) + if endusers_to_reset + else 0, + "endusers_found": json.dumps( + endusers_to_reset, indent=4, default=str + ), + "num_endusers_updated": len(updated_endusers), + "endusers_updated": json.dumps( + updated_endusers, indent=4, default=str + ), + "num_endusers_failed": len(failed_endusers), + "endusers_failed": json.dumps( + failed_endusers, indent=4, default=str + ), + }, + ) + ) + except Exception as e: + end_time = time.time() + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_failure_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + error=e, + call_type="reset_budget_endusers", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_budgets_found": len(budgets_to_reset) + if budgets_to_reset + else 0, + "budgets_found": json.dumps( + budgets_to_reset, indent=4, default=str + ), + "num_endusers_found": len(endusers_to_reset) + if endusers_to_reset + else 0, + "endusers_found": json.dumps( + endusers_to_reset, indent=4, default=str + ), + }, + ) + ) + verbose_proxy_logger.exception("Failed to reset budget for endusers: %s", e) + async def reset_budget_for_litellm_keys(self): """ Resets the budget for all the litellm keys @@ -355,6 +492,46 @@ class ResetBudgetJob: ) return user + @staticmethod + async def 
_reset_budget_for_enduser( + enduser: LiteLLM_EndUserTable, + ) -> Optional[LiteLLM_EndUserTable]: + try: + enduser.spend = 0.0 + except Exception as e: + verbose_proxy_logger.exception( + "Error resetting budget for enduser: %s. Item: %s", e, enduser + ) + raise e + return enduser + + @staticmethod + async def _reset_budget_reset_at_date( + budget: LiteLLM_BudgetTableFull, current_time: datetime + ) -> Optional[LiteLLM_BudgetTableFull]: + try: + if budget.budget_duration is not None: + duration_s = duration_in_seconds(duration=budget.budget_duration) + + # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account + if ( + budget.budget_reset_at is None + and budget.created_at + timedelta(seconds=duration_s) > current_time + ): + budget.budget_reset_at = budget.created_at + timedelta( + seconds=duration_s + ) + else: + budget.budget_reset_at = current_time + timedelta( + seconds=duration_s + ) + except Exception as e: + verbose_proxy_logger.exception( + "Error resetting budget_reset_at for budget: %s. 
Item: %s", e, budget + ) + raise e + return budget + @staticmethod async def _reset_budget_for_key( key: LiteLLM_VerificationToken, current_time: datetime diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 65b0156afe..54afdc2c32 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -12,8 +12,10 @@ All /budget management endpoints """ #### BUDGET TABLE MANAGEMENT #### +from datetime import timedelta, timezone from fastapi import APIRouter, Depends, HTTPException +from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.utils import jsonify_object @@ -51,6 +53,12 @@ async def new_budget( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) + # if no budget_reset_at date is set, but a budget_duration is given, then set budget_reset_at initially to the first completed duration interval in future (timezone-aware UTC) + if budget_obj.budget_reset_at is None and budget_obj.budget_duration is not None: + budget_obj.budget_reset_at = datetime.now(timezone.utc) + timedelta( + seconds=duration_in_seconds(duration=budget_obj.budget_duration) + ) + budget_obj_json = budget_obj.model_dump(exclude_none=True) budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries response = await prisma_client.db.litellm_budgettable.create( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 6e8c65710d..7fb4ed9ebe 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1410,6 +1410,8 @@ class PrismaClient: "key", "config", "spend", + "enduser", + "budget", "team", "user_notification", "combined_view", @@ -1424,6 +1426,7 @@ class PrismaClient: ] = None, # pagination, number of rows to getch when find_all==True parent_otel_span: 
Optional[Span] = None, proxy_logging_obj: Optional[ProxyLogging] = None, + budget_id_list: Optional[List[str]] = None, ): args_passed_in = locals() start_time = time.time() @@ -1603,6 +1606,29 @@ class PrismaClient: order={"startTime": "desc"}, ) return response + elif table_name == "budget" and reset_at is not None: + if query_type == "find_all": + response = await self.db.litellm_budgettable.find_many( + where={ # type:ignore + "OR": [ + { + "AND": [ + {"budget_reset_at": None}, + {"NOT": {"budget_duration": None}}, + ] + }, + {"budget_reset_at": {"lt": reset_at}}, + ] + } + ) + return response + + elif table_name == "enduser" and budget_id_list is not None: + if query_type == "find_all": + response = await self.db.litellm_endusertable.find_many( + where={"budget_id": {"in": budget_id_list}} + ) + return response elif table_name == "team": if query_type == "find_unique": response = await self.db.litellm_teamtable.find_unique( @@ -1916,7 +1942,9 @@ class PrismaClient: user_id: Optional[str] = None, team_id: Optional[str] = None, query_type: Literal["update", "update_many"] = "update", - table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, + table_name: Optional[ + Literal["user", "key", "config", "spend", "team", "enduser", "budget"] + ] = None, update_key_values: Optional[dict] = None, update_key_values_custom_query: Optional[dict] = None, ): @@ -2083,6 +2111,68 @@ class PrismaClient: verbose_proxy_logger.info( "\033[91m" + "DB User Table Batch update succeeded" + "\033[0m" ) + elif ( + table_name is not None + and table_name == "enduser" + and query_type == "update_many" + and data_list is not None + and isinstance(data_list, list) + ): + """ + Batch write update queries + """ + batcher = self.db.batch_() + for enduser in data_list: + try: + data_json = self.jsonify_object( + data=enduser.model_dump(exclude_none=True) + ) + except Exception: + data_json = self.jsonify_object(data=enduser.dict()) + batcher.litellm_endusertable.upsert( + 
where={"user_id": enduser.user_id}, # type: ignore + data={ + "create": {**data_json}, # type: ignore + "update": { + **data_json # type: ignore + }, # just update end-user-specified values, if it already exists + }, + ) + await batcher.commit() + verbose_proxy_logger.info( + "\033[91m" + "DB End User Table Batch update succeeded" + "\033[0m" + ) + elif ( + table_name is not None + and table_name == "budget" + and query_type == "update_many" + and data_list is not None + and isinstance(data_list, list) + ): + """ + Batch write update queries + """ + batcher = self.db.batch_() + for budget in data_list: + try: + data_json = self.jsonify_object( + data=budget.model_dump(exclude_none=True) + ) + except Exception: + data_json = self.jsonify_object(data=budget.dict()) + batcher.litellm_budgettable.upsert( + where={"budget_id": budget.budget_id}, # type: ignore + data={ + "create": {**data_json}, # type: ignore + "update": { + **data_json # type: ignore + }, # just update end-user-specified values, if it already exists + }, + ) + await batcher.commit() + verbose_proxy_logger.info( + "\033[91m" + "DB Budget Table Batch update succeeded" + "\033[0m" + ) elif ( table_name is not None and table_name == "team" diff --git a/tests/litellm/proxy/common_utils/test_reset_budget_job.py b/tests/litellm/proxy/common_utils/test_reset_budget_job.py index bb4af00d78..886ca34b71 100644 --- a/tests/litellm/proxy/common_utils/test_reset_budget_job.py +++ b/tests/litellm/proxy/common_utils/test_reset_budget_job.py @@ -1,12 +1,9 @@ import asyncio -import json import os import sys -import time -from datetime import datetime, timedelta +from datetime import datetime, timezone import pytest -from fastapi.testclient import TestClient sys.path.insert( 0, os.path.abspath("../../..") @@ -20,8 +17,14 @@ from litellm.proxy.utils import ProxyLogging # Mock classes for testing class MockPrismaClient: def __init__(self): - self.data = {"key": [], "user": [], "team": []} - self.updated_data = {"key": [], 
"user": [], "team": []} + self.data = {"key": [], "user": [], "team": [], "budget": [], "enduser": []} + self.updated_data = { + "key": [], + "user": [], + "team": [], + "budget": [], + "enduser": [], + } async def get_data(self, table_name, query_type, **kwargs): return self.data.get(table_name, []) @@ -145,9 +148,48 @@ def test_reset_budget_for_team(reset_budget_job, mock_prisma_client): assert updated_team.budget_reset_at > now +def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client): + # Setup test data + now = datetime.now(timezone.utc) + test_budget = type( + "LiteLLM_BudgetTable", + (), + { + "max_budget": 500.0, + "budget_duration": "1d", + "budget_reset_at": now, + "budget_id": "test-budget-1", + }, + ) + + test_enduser = type( + "LiteLLM_EndUserTable", + (), + { + "spend": 20.0, + "litellm_budget_table": test_budget, + "user_id": "test-enduser-1", + }, + ) + + mock_prisma_client.data["budget"] = [test_budget] + mock_prisma_client.data["enduser"] = [test_enduser] + + # Run the test + asyncio.run(reset_budget_job.reset_budget_for_litellm_endusers()) + + # Verify results + assert len(mock_prisma_client.updated_data["enduser"]) == 1 + assert len(mock_prisma_client.updated_data["budget"]) == 1 + updated_enduser = mock_prisma_client.updated_data["enduser"][0] + updated_budget = mock_prisma_client.updated_data["budget"][0] + assert updated_enduser.spend == 0.0 + assert updated_budget.budget_reset_at > now + + def test_reset_budget_all(reset_budget_job, mock_prisma_client): # Setup test data - now = datetime.utcnow() + now = datetime.now(timezone.utc) # Create test objects for all three types test_key = type( @@ -183,9 +225,32 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client): }, ) + test_budget = type( + "LiteLLM_BudgetTable", + (), + { + "max_budget": 500.0, + "budget_duration": "1d", + "budget_reset_at": now, + "budget_id": "test-budget-1", + }, + ) + + test_enduser = type( + "LiteLLM_EndUserTable", + (), + { + "spend": 20.0, + 
"litellm_budget_table": test_budget, + "user_id": "test-enduser-1", + }, + ) + mock_prisma_client.data["key"] = [test_key] mock_prisma_client.data["user"] = [test_user] mock_prisma_client.data["team"] = [test_team] + mock_prisma_client.data["budget"] = [test_budget] + mock_prisma_client.data["enduser"] = [test_enduser] # Run the test asyncio.run(reset_budget_job.reset_budget()) @@ -194,8 +259,11 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client): assert len(mock_prisma_client.updated_data["key"]) == 1 assert len(mock_prisma_client.updated_data["user"]) == 1 assert len(mock_prisma_client.updated_data["team"]) == 1 + assert len(mock_prisma_client.updated_data["enduser"]) == 1 + assert len(mock_prisma_client.updated_data["budget"]) == 1 # Check that all spends were reset to 0 assert mock_prisma_client.updated_data["key"][0].spend == 0.0 assert mock_prisma_client.updated_data["user"][0].spend == 0.0 assert mock_prisma_client.updated_data["team"][0].spend == 0.0 + assert mock_prisma_client.updated_data["enduser"][0].spend == 0.0 diff --git a/tests/litellm_utils_tests/test_proxy_budget_reset.py b/tests/litellm_utils_tests/test_proxy_budget_reset.py index 1fbe493d8d..a3c9523941 100644 --- a/tests/litellm_utils_tests/test_proxy_budget_reset.py +++ b/tests/litellm_utils_tests/test_proxy_budget_reset.py @@ -1,32 +1,22 @@ +import asyncio import os import sys -import time -import traceback -import uuid -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from dotenv import load_dotenv -import json -import asyncio load_dotenv() import os -import tempfile -from uuid import uuid4 + +from litellm.proxy._types import LiteLLM_BudgetTableFull sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob -from litellm.proxy._types import ( - 
LiteLLM_VerificationToken, - LiteLLM_UserTable, - LiteLLM_TeamTable, -) -from litellm.types.services import ServiceTypes # Note: In our "fake" items we use dicts with fields that our fake reset functions modify. # In a real-world scenario, these would be instances of LiteLLM_VerificationToken, LiteLLM_UserTable, etc. @@ -180,6 +170,106 @@ async def test_reset_budget_users_partial_failure(): ) +@pytest.mark.asyncio +async def test_reset_budget_endusers_partial_failure(): + """ + Test that if one enduser fails to reset, the reset loop still processes the other endusers. + We simulate six endusers where the first fails and the others are updated. + """ + user1 = { + "user_id": "user1", + "spend": 20.0, + "budget_id": "budget1", + } # Will trigger simulated failure + user2 = { + "user_id": "user2", + "spend": 25.0, + "budget_id": "budget1", + } # Should be updated + user3 = { + "user_id": "user3", + "spend": 30.0, + "budget_id": "budget1", + } # Should be updated + user4 = { + "user_id": "user4", + "spend": 35.0, + "budget_id": "budget1", + } # Should be updated + user5 = { + "user_id": "user5", + "spend": 40.0, + "budget_id": "budget1", + } # Should be updated + user6 = { + "user_id": "user6", + "spend": 45.0, + "budget_id": "budget1", + } # Should be updated + + budget1 = LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) + + prisma_client = MagicMock() + + async def get_data_mock(table_name, *args, **kwargs): + if table_name == "budget": + return [budget1] + elif table_name == "enduser": + return [user1, user2, user3, user4, user5, user6] + return [] + + prisma_client.get_data = AsyncMock() + prisma_client.get_data.side_effect = get_data_mock + + prisma_client.update_data = AsyncMock() + + proxy_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj.async_service_success_hook = 
AsyncMock() + proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock() + + job = ResetBudgetJob(proxy_logging_obj, prisma_client) + + async def fake_reset_enduser(enduser): + if enduser["user_id"] == "user1": + raise Exception("Simulated failure for user1") + enduser["spend"] = 0.0 + return enduser + + with patch.object( + ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser + ) as mock_reset_enduser: + await job.reset_budget_for_litellm_endusers() + await asyncio.sleep(0.1) + + assert mock_reset_enduser.call_count == 6 + assert prisma_client.update_data.await_count == 2 + update_call = prisma_client.update_data.call_args + assert update_call.kwargs.get("table_name") == "enduser" + updated_users = update_call.kwargs.get("data_list", []) + assert len(updated_users) == 5 + assert updated_users[0]["user_id"] == "user2" + assert updated_users[1]["user_id"] == "user3" + assert updated_users[2]["user_id"] == "user4" + assert updated_users[3]["user_id"] == "user5" + assert updated_users[4]["user_id"] == "user6" + + failure_hook_calls = ( + proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list + ) + assert any( + call.kwargs.get("call_type") == "reset_budget_endusers" + for call in failure_hook_calls + ) + + @pytest.mark.asyncio async def test_reset_budget_teams_partial_failure(): """ @@ -263,6 +353,15 @@ async def test_reset_budget_continues_other_categories_on_failure(): user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Succeeds team1 = {"id": "team1", "spend": 30.0, "budget_duration": 180} team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180} + enduser1 = {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"} + budget1 = LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) prisma_client = MagicMock() @@ -273,6 +372,10 @@ async def 
test_reset_budget_continues_other_categories_on_failure(): return [user1, user2] elif table_name == "team": return [team1, team2] + elif table_name == "budget": + return [budget1] + elif table_name == "enduser": + return [enduser1] return [] prisma_client.get_data = AsyncMock(side_effect=fake_get_data) @@ -308,13 +411,19 @@ async def test_reset_budget_continues_other_categories_on_failure(): ).isoformat() return team + async def fake_reset_enduser(enduser): + enduser["spend"] = 0.0 + return enduser + with patch.object( ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key ) as mock_reset_key, patch.object( ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user ) as mock_reset_user, patch.object( ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team - ) as mock_reset_team: + ) as mock_reset_team, patch.object( + ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser + ) as mock_reset_enduser: # Call the overall reset_budget method. await job.reset_budget() await asyncio.sleep(0.1) @@ -323,10 +432,10 @@ async def test_reset_budget_continues_other_categories_on_failure(): called_tables = { call.kwargs.get("table_name") for call in prisma_client.get_data.await_args_list } - assert called_tables == {"key", "user", "team"} + assert called_tables == {"key", "user", "team", "budget", "enduser"} - # Verify that update_data was called three times (one per category) - assert prisma_client.update_data.await_count == 3 + # Verify that update_data was called five times (one each for keys, users and teams; the enduser reset issues two: budget table, then enduser table) + assert prisma_client.update_data.await_count == 5 calls = prisma_client.update_data.await_args_list # Check keys update: both keys succeed. @@ -346,9 +455,14 @@ async def test_reset_budget_continues_other_categories_on_failure(): assert teams_call.kwargs.get("table_name") == "team" assert len(teams_call.kwargs.get("data_list", [])) == 2 + # Check enduser update: enduser succeed. 
+ enduser_call = calls[4] + assert enduser_call.kwargs.get("table_name") == "enduser" + assert len(enduser_call.kwargs.get("data_list", [])) == 1 + # --------------------------------------------------------------------------- -# Additional tests for service logger behavior (keys, users, teams) +# Additional tests for service logger behavior (keys, users, teams, endusers) # --------------------------------------------------------------------------- @@ -395,9 +509,10 @@ async def test_service_logger_keys_success(): # Verify success hook call proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_keys_found") == len(keys) assert event_metadata.get("num_keys_updated") == len(keys) @@ -456,9 +571,10 @@ async def test_service_logger_keys_failure(): ) proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_keys_found") == len(keys) keys_found_str = event_metadata.get("keys_found", "") @@ -508,9 +624,10 @@ async def test_service_logger_users_success(): mock_verbose_exc.assert_not_called() proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert 
event_metadata.get("num_users_found") == len(users) assert event_metadata.get("num_users_updated") == len(users) @@ -567,9 +684,10 @@ async def test_service_logger_users_failure(): ) proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_users_found") == len(users) users_found_str = event_metadata.get("users_found", "") @@ -618,9 +736,10 @@ async def test_service_logger_teams_success(): mock_verbose_exc.assert_not_called() proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_teams_found") == len(teams) assert event_metadata.get("num_teams_updated") == len(teams) @@ -677,11 +796,156 @@ async def test_service_logger_teams_failure(): ) proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_teams_found") == len(teams) teams_found_str = event_metadata.get("teams_found", "") assert "team1" in teams_found_str proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called() + + +@pytest.mark.asyncio +async def test_service_logger_endusers_success(): + """ + Test that when resetting endusers succeeds the service logger 
success hook is called with + the correct metadata and no exception is logged. + """ + endusers = [ + {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, + {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"}, + ] + budgets = [ + LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) + ] + + async def fake_get_data(*, table_name, query_type, **kwargs): + if table_name == "budget": + return budgets + elif table_name == "enduser": + return endusers + return [] + + prisma_client = MagicMock() + prisma_client.get_data = AsyncMock(side_effect=fake_get_data) + prisma_client.update_data = AsyncMock() + + proxy_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock() + proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock() + + job = ResetBudgetJob(proxy_logging_obj, prisma_client) + + async def fake_reset_enduser(enduser): + enduser["spend"] = 0.0 + return enduser + + with patch.object( + ResetBudgetJob, + "_reset_budget_for_enduser", + side_effect=fake_reset_enduser, + ): + with patch( + "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception" + ) as mock_verbose_exc: + await job.reset_budget_for_litellm_endusers() + await asyncio.sleep(0.1) + mock_verbose_exc.assert_not_called() + + proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args + event_metadata = kwargs.get("event_metadata", {}) + assert event_metadata.get("num_budgets_found") == len(budgets) + assert event_metadata.get("num_endusers_found") == len(endusers) + assert event_metadata.get("num_endusers_updated") == len(endusers) + assert event_metadata.get("num_endusers_failed") == 0 + 
proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called() + + +@pytest.mark.asyncio +async def test_service_logger_endusers_failure(): + """ + Test that a failure during enduser reset calls the failure hook with appropriate metadata, + logs the exception, and does not call the success hook. + """ + endusers = [ + {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, + {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"}, + ] + budgets = [ + LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) + ] + + async def fake_get_data(*, table_name, query_type, **kwargs): + if table_name == "budget": + return budgets + elif table_name == "enduser": + return endusers + return [] + + prisma_client = MagicMock() + prisma_client.get_data = AsyncMock(side_effect=fake_get_data) + prisma_client.update_data = AsyncMock() + + proxy_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock() + proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock() + + job = ResetBudgetJob(proxy_logging_obj, prisma_client) + + async def fake_reset_enduser(enduser): + if enduser["user_id"] == "user1": + raise Exception("Simulated failure for user1") + enduser["spend"] = 0.0 + return enduser + + with patch.object( + ResetBudgetJob, + "_reset_budget_for_enduser", + side_effect=fake_reset_enduser, + ): + with patch( + "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception" + ) as mock_verbose_exc: + await job.reset_budget_for_litellm_endusers() + await asyncio.sleep(0.1) + # Verify exception logging + assert mock_verbose_exc.call_count >= 1 + # Verify exception was logged with correct message + assert any( + "Failed to reset budget for enduser" in str(call.args) + for call in 
mock_verbose_exc.call_args_list + ) + + proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args + event_metadata = kwargs.get("event_metadata", {}) + assert event_metadata.get("num_budgets_found") == len(budgets) + assert event_metadata.get("num_endusers_found") == len(endusers) + endusers_found_str = event_metadata.get("endusers_found", "") + assert "user1" in endusers_found_str + proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called() diff --git a/tests/local_testing/test_enduser_spend_reset.py b/tests/local_testing/test_enduser_spend_reset.py new file mode 100644 index 0000000000..832ba9b324 --- /dev/null +++ b/tests/local_testing/test_enduser_spend_reset.py @@ -0,0 +1,266 @@ +import asyncio +import os +import sys +import uuid +from datetime import datetime, timedelta, timezone + +import aiohttp +import pytest +import pytest_asyncio +from dotenv import load_dotenv + +from litellm.caching.caching import DualCache +from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob +from litellm.proxy.utils import PrismaClient, ProxyLogging + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +load_dotenv() + + +async def create_budget(session, data): + url = "http://0.0.0.0:4000/budget/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + budget_id = response_data["budget_id"] + print(f"Created Budget {budget_id}") + return response_data + + +async def create_end_user(prisma_client, session, user_id, budget_id, spend=None): + url = "http://0.0.0.0:4000/end_user/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "user_id": 
user_id, + "budget_id": budget_id, + } + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + end_user_id = response_data["user_id"] + print(f"Created End User {end_user_id}") + + if spend is not None: + end_users = await prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[budget_id], + ) + end_user = [user for user in end_users if user.user_id == user_id][0] + end_user.spend = spend + await prisma_client.update_data( + query_type="update_many", + data_list=[end_user], + table_name="enduser", + ) + + return response_data + + +async def delete_budget(session, budget_id): + url = "http://0.0.0.0:4000/budget/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"id": budget_id} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted Budget {budget_id}") + + +async def delete_end_user(session, user_id): + url = "http://0.0.0.0:4000/end_user/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"user_ids": [user_id]} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted End User {user_id}") + + +@pytest.fixture +def prisma_client(): + proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + return prisma_client + + +class MockProxyLogging: + class MockServiceLogging: + async def async_service_success_hook(self, **kwargs): + pass + + async def async_service_failure_hook(self, **kwargs): + pass + + def __init__(self): + self.service_logging_obj = self.MockServiceLogging() + + +@pytest.fixture +def mock_proxy_logging(): + return MockProxyLogging() + + +@pytest.fixture +def 
reset_budget_job(prisma_client, mock_proxy_logging): + return ResetBudgetJob( + proxy_logging_obj=mock_proxy_logging, prisma_client=prisma_client + ) + + +@pytest_asyncio.fixture +async def budget_and_enduser_setup(prisma_client): + """ + Fixture to set up budgets and end users for testing and clean them up afterward. + + This fixture performs the following: + - Creates two budgets: + * Budget X with a short duration ("5s"). + * Budget Y with a long duration ("30d"). + - Stores the initial 'budget_reset_at' timestamps for both budgets. + - Creates three end users: + * End Users A and B are associated with Budget X and are given initial spend values. + * End User C is associated with Budget Y with an initial spend. + - After the test (after the yield), the created end users and budgets are deleted. + """ + await prisma_client.connect() + + async with aiohttp.ClientSession() as session: + # Create budgets + id_budget_x = f"budget-{uuid.uuid4()}" + data_budget_x = { + "budget_id": id_budget_x, + "budget_duration": "5s", + "max_budget": 2, + } + id_budget_y = f"budget-{uuid.uuid4()}" + data_budget_y = { + "budget_id": id_budget_y, + "budget_duration": "30d", + "max_budget": 1, + } + response_budget_x = await create_budget(session, data_budget_x) + initial_budget_x_reset_at_date = datetime.fromisoformat( + response_budget_x["budget_reset_at"] + ) + response_budget_y = await create_budget(session, data_budget_y) + initial_budget_y_reset_at_date = datetime.fromisoformat( + response_budget_y["budget_reset_at"] + ) + + # Create end users + id_end_user_a = f"test-user-{uuid.uuid4()}" + id_end_user_b = f"test-user-{uuid.uuid4()}" + id_end_user_c = f"test-user-{uuid.uuid4()}" + await create_end_user( + prisma_client, session, id_end_user_a, id_budget_x, spend=0.16 + ) + await create_end_user( + prisma_client, session, id_end_user_b, id_budget_x, spend=0.04 + ) + await create_end_user( + prisma_client, session, id_end_user_c, id_budget_y, spend=0.2 + ) + + # Bundle data needed 
for the test + setup_data = { + "budgets": { + "id_budget_x": id_budget_x, + "id_budget_y": id_budget_y, + "initial_budget_x_reset_at_date": initial_budget_x_reset_at_date, + "initial_budget_y_reset_at_date": initial_budget_y_reset_at_date, + }, + "end_users": { + "id_end_user_a": id_end_user_a, + "id_end_user_b": id_end_user_b, + "id_end_user_c": id_end_user_c, + }, + } + + # Provide the setup data to the test + yield setup_data + + # Clean-up: Delete the created test data + await delete_end_user(session, id_end_user_a) + await delete_end_user(session, id_end_user_b) + await delete_end_user(session, id_end_user_c) + await delete_budget(session, id_budget_x) + await delete_budget(session, id_budget_y) + + +@pytest.mark.asyncio +async def test_reset_budget_for_endusers( + reset_budget_job, prisma_client, budget_and_enduser_setup +): + """ + Test the part "Reset End-User (Customer) Spend and corresponding Budget duration" in reset_budget function. + + This test uses the budget_and_enduser_setup fixture to create budgets and end users, + waits for the short-duration budget to expire, calls reset_budget, and verifies that: + - End users associated with the short-duration budget X have their spend reset to 0. + - The budget_reset_at timestamp for the short-duration budget X is updated, + while the long-duration budget Y remains unchanged. 
+ """ + + # Unpack the required data from the fixture + budgets = budget_and_enduser_setup["budgets"] + end_users = budget_and_enduser_setup["end_users"] + + id_budget_x = budgets["id_budget_x"] + id_budget_y = budgets["id_budget_y"] + initial_budget_x_reset_at_date = budgets["initial_budget_x_reset_at_date"] + initial_budget_y_reset_at_date = budgets["initial_budget_y_reset_at_date"] + + id_end_user_a = end_users["id_end_user_a"] + id_end_user_b = end_users["id_end_user_b"] + id_end_user_c = end_users["id_end_user_c"] + + # Wait for Budget X to expire (short duration "5s" plus a small buffer) + await asyncio.sleep(6) + + # Call the reset_budget function: + # It should reset the spend values for end users associated with Budget X. + await reset_budget_job.reset_budget_for_litellm_endusers() + + # Retrieve updated data for end users + updated_end_users = await prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[id_budget_x, id_budget_y], + ) + # Retrieve updated data for budgets + updated_budgets = await prisma_client.get_data( + table_name="budget", + query_type="find_all", + reset_at=datetime.now(timezone.utc) + timedelta(days=31), + ) + + # Assertions for end users + user_a = [user for user in updated_end_users if user.user_id == id_end_user_a][0] + user_b = [user for user in updated_end_users if user.user_id == id_end_user_b][0] + user_c = [user for user in updated_end_users if user.user_id == id_end_user_c][0] + + assert user_a.spend == 0, "Spend for end_user_a was not reset to 0" + assert user_b.spend == 0, "Spend for end_user_b was not reset to 0" + assert user_c.spend > 0, "Spend for end_user_c should not be reset" + + # Assertions for budgets + budget_x = [ + budget for budget in updated_budgets if budget.budget_id == id_budget_x + ][0] + budget_y = [ + budget for budget in updated_budgets if budget.budget_id == id_budget_y + ][0] + + assert ( + budget_x.budget_reset_at > initial_budget_x_reset_at_date + ), "Budget X 
budget_reset_at was not updated" + assert ( + budget_y.budget_reset_at == initial_budget_y_reset_at_date + ), "Budget Y budget_reset_at should remain unchanged" diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 1281d50863..c974cbc7d4 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1792,4 +1792,4 @@ async def test_get_admin_team_ids( where={"team_id": {"in": user_info.teams}} ) else: - mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called() + mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called() \ No newline at end of file diff --git a/tests/test_budget_management.py b/tests/test_budget_management.py new file mode 100644 index 0000000000..8175a9b1df --- /dev/null +++ b/tests/test_budget_management.py @@ -0,0 +1,90 @@ +# What is this? +## Unit tests for the /budget/* endpoints +import uuid +from datetime import datetime, timedelta + +import aiohttp +import pytest +import pytest_asyncio + + +async def delete_budget(session, budget_id): + url = "http://0.0.0.0:4000/budget/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"id": budget_id} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted Budget {budget_id}") + + +async def create_budget(session, data): + url = "http://0.0.0.0:4000/budget/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + budget_id = response_data["budget_id"] + print(f"Created Budget {budget_id}") + return response_data + + +@pytest_asyncio.fixture +async def budget_setup(): + """ + Fixture to create a budget for testing and clean it up afterward. + + This fixture performs the following steps: + 1. 
Opens an aiohttp ClientSession. + 2. Generates a random budget_id and defines the budget data (duration: 1 day, max_budget: 0.02). + 3. Calls create_budget to create the budget. + 4. Yields the budget_response (a dict) for use in the test. + 5. After the test completes, deletes the created budget by calling delete_budget. + + Returns: + dict: The JSON response from create_budget, which includes the created budget's data. + """ + + async with aiohttp.ClientSession() as session: + # Generate a unique budget_id and define the budget data. + budget_id = f"budget-{uuid.uuid4()}" + data = {"budget_id": budget_id, "budget_duration": "1d", "max_budget": 0.02} + budget_response = await create_budget(session, data) + + # Yield the response so the test can use it. + yield budget_response + + # After the test, delete the created budget to clean up. + await delete_budget(session, budget_id) + + +@pytest.mark.asyncio +async def test_create_budget_with_duration(budget_setup): + """ + Test creating a budget with a specified duration and verify that the 'budget_reset_at' + timestamp is correctly calculated as 'created_at' plus the budget duration (one day). + + This test uses the budget_setup fixture, which handles both the creation and cleanup of the budget. + """ + + # Verify that the response includes a 'budget_reset_at' timestamp. + assert ( + budget_setup["budget_reset_at"] is not None + ), "The budget_reset_at field should not be None" + + # Calculate the expected reset time: created_at + 1 day. + expected_reset_at_date = datetime.fromisoformat( + budget_setup["created_at"] + ) + timedelta(days=1) + + # Allow for a small tolerance in seconds for the timestamp calculation. 
+ tolerance_seconds = 3 + actual_reset_at_date = datetime.fromisoformat(budget_setup["budget_reset_at"]) + time_difference = abs( + (actual_reset_at_date - expected_reset_at_date).total_seconds() + ) + + assert time_difference <= tolerance_seconds, ( + f"Expected budget_reset_at to be within {tolerance_seconds} seconds of {expected_reset_at_date}, " + f"but the difference was {time_difference} seconds." + )