From 0b6d941dc2ca1be03255942070c21084b5e51f15 Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Mon, 10 Feb 2025 12:59:36 +0100 Subject: [PATCH 01/10] feat: update enduser spend and budget reset date based on budget duration --- litellm/proxy/_types.py | 4 + .../budget_management_endpoints.py | 9 +- litellm/proxy/utils.py | 120 +++++++++++++++++- 3 files changed, 129 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 713925c638..410d4005f3 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -764,6 +764,10 @@ class BudgetNewRequest(LiteLLMPydanticObjectBase): default=None, description="Max budget for each model (e.g. {'gpt-4o': {'max_budget': '0.0000001', 'budget_duration': '1d', 'tpm_limit': 1000, 'rpm_limit': 1000}})", ) + budget_reset_at: Optional[datetime] = Field( + default=None, + description="Datetime when the budget is reset", + ) class BudgetRequest(LiteLLMPydanticObjectBase): diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 20aa1c6bbf..54afdc2c32 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -12,8 +12,10 @@ All /budget management endpoints """ #### BUDGET TABLE MANAGEMENT #### +from datetime import timedelta from fastapi import APIRouter, Depends, HTTPException +from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.utils import jsonify_object @@ -51,6 +53,12 @@ async def new_budget( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) + # if no budget_reset_at date is set, but a budget_duration is given, then set budget_reset_at initially to the first completed duration interval in future + if budget_obj.budget_reset_at is None and budget_obj.budget_duration is not None: + budget_obj.budget_reset_at = datetime.utcnow() + timedelta( + seconds=duration_in_seconds(duration=budget_obj.budget_duration) + ) + budget_obj_json = budget_obj.model_dump(exclude_none=True) budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries response = await prisma_client.db.litellm_budgettable.create( @@ -197,7 +205,6 @@ async def budget_settings( for field_name, field_info in BudgetNewRequest.model_fields.items(): if field_name in allowed_args: - _stored_in_db = True _response_obj = ConfigList( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 51f235522d..b43d6becbc 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -8,7 +8,7 @@ import smtplib import threading import time import traceback -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload @@ -482,7 +482,6 @@ class ProxyLogging: try: for callback in litellm.callbacks: - _callback = None if isinstance(callback, str): _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( @@ -1290,6 +1289,8 @@ class PrismaClient: "key", "config", "spend", + "enduser", + "budget", "team", "user_notification", "combined_view", @@ -1304,6 +1305,7 @@ class PrismaClient: ] = None, # pagination, number of rows to getch when find_all==True parent_otel_span: Optional[Span] = None, proxy_logging_obj: Optional[ProxyLogging] = None, + budget_id_list: Optional[List[str]] = None, ): args_passed_in = locals() start_time = time.time() @@ -1482,6 +1484,21 @@ class PrismaClient: order={"startTime": "desc"}, ) return response + elif table_name == "budget" and reset_at is not None: + if query_type == "find_all": + response = await self.db.litellm_budgettable.find_many( + where={ # type:ignore + "budget_reset_at": {"lt": reset_at} + } + ) + return response + + elif table_name == "enduser" and budget_id_list is not None: + if query_type == "find_all": + response = await self.db.litellm_endusertable.find_many( + where={"budget_id": {"in": budget_id_list}} + ) + return response elif table_name == "team": if query_type == "find_unique": response = await self.db.litellm_teamtable.find_unique( @@ -1799,7 +1816,9 @@ class PrismaClient: user_id: Optional[str] = None, team_id: Optional[str] = None, query_type: Literal["update", "update_many"] = "update", - table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, + table_name: Optional[ + Literal["user", "key", "config", "spend", "team", "enduser", "budget"] + ] = None, update_key_values: Optional[dict] = None, update_key_values_custom_query: Optional[dict] = None, ): @@ -1966,6 +1985,68 @@ class PrismaClient: verbose_proxy_logger.info( "\033[91m" + "DB User Table Batch update succeeded" + "\033[0m" ) + elif ( + table_name is not None + and table_name == "enduser" + and query_type == "update_many" + and data_list is not None + and isinstance(data_list, list) + ): + """ + Batch write update queries + """ + batcher = self.db.batch_() + for enduser in data_list: + try: + data_json = self.jsonify_object( + data=enduser.model_dump(exclude_none=True) + ) + except Exception: + data_json = self.jsonify_object(data=enduser.dict()) + batcher.litellm_endusertable.upsert( + where={"user_id": enduser.user_id}, # type: ignore + data={ + "create": {**data_json}, # type: ignore + "update": { + **data_json # type: ignore + }, # just update end-user-specified values, if it already exists + }, + ) + await batcher.commit() + verbose_proxy_logger.info( + "\033[91m" + "DB End User Table Batch update succeeded" + "\033[0m" + ) + elif ( + table_name is not None + and table_name == "budget" + and query_type == "update_many" + and data_list is not None + and isinstance(data_list, list) + ): + """ + Batch write update queries + """ + batcher = self.db.batch_() + for budget in data_list: + try: + data_json = self.jsonify_object( + data=budget.model_dump(exclude_none=True) + ) + except Exception: + data_json = self.jsonify_object(data=budget.dict()) + batcher.litellm_budgettable.upsert( + where={"budget_id": budget.budget_id}, # type: ignore + data={ + "create": {**data_json}, # type: ignore + "update": { + **data_json # type: ignore + }, # just update end-user-specified values, if it already exists + }, + ) + await batcher.commit() + verbose_proxy_logger.info( + "\033[91m" + "DB Budget Table Batch update succeeded" + "\033[0m" + ) elif ( table_name is not None and table_name == "team" @@ -2408,6 +2489,39 @@ async def reset_budget(prisma_client: PrismaClient): query_type="update_many", data_list=users_to_reset, table_name="user" ) + ## Reset End-User (Customer) Spend and corresponding Budget duration + now = datetime.now(timezone.utc) + + budgets_to_reset = await prisma_client.get_data( + table_name="budget", query_type="find_all", reset_at=now + ) + budget_id_list_to_reset_enduser = [] + if budgets_to_reset is not None and len(budgets_to_reset) > 0: + for budget in budgets_to_reset: + budget_id_list_to_reset_enduser.append(budget.budget_id) + duration_s = duration_in_seconds(duration=budget.budget_duration) + budget.budget_reset_at = now + timedelta(seconds=duration_s) + await prisma_client.update_data( + query_type="update_many", + data_list=budgets_to_reset, + table_name="budget", + ) + + endusers_to_reset = await prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=budget_id_list_to_reset_enduser, + ) + + if endusers_to_reset is not None and len(endusers_to_reset) > 0: + for enduser in endusers_to_reset: + enduser.spend = 0.0 + await prisma_client.update_data( + query_type="update_many", + data_list=endusers_to_reset, + table_name="enduser", + ) + ## Reset Team Budget now = datetime.utcnow() teams_to_reset = await prisma_client.get_data( From 6b5469cd45674e63d2bec9476f9f1b2b47617653 Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Tue, 11 Feb 2025 15:20:45 +0100 Subject: [PATCH 02/10] adding tests for budget creation and reset_budget function --- tests/proxy_unit_tests/test_proxy_utils.py | 305 ++++++++++++++++++--- tests/test_budget_management.py | 90 ++++++ 2 files changed, 363 insertions(+), 32 deletions(-) create mode 100644 tests/test_budget_management.py diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 36f9b6652f..081db6638c 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1,27 +1,38 @@ import asyncio +import json import os import sys +import uuid +from datetime import datetime from typing import Any, Dict -from unittest.mock import Mock -from litellm.proxy.utils import _get_redoc_url, _get_docs_url -import json +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import aiohttp import pytest +import pytest_asyncio from fastapi import Request -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path import litellm -from unittest.mock import MagicMock, patch, AsyncMock - +from litellm.caching.caching import DualCache from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.auth_utils import is_request_body_safe from litellm.proxy.litellm_pre_call_utils import ( _get_dynamic_logging_metadata, add_litellm_data_to_request, ) +from litellm.proxy.utils import ( + PrismaClient, + ProxyLogging, + _get_docs_url, + _get_redoc_url, + reset_budget, +) from litellm.types.utils import SupportedCacheControls +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + @pytest.fixture def mock_request(monkeypatch): @@ -488,8 +499,9 @@ def test_reading_openai_org_id_from_headers(): ) def test_add_litellm_data_for_backend_llm_call(headers, expected_data): import json - from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" @@ -509,8 +521,8 @@ def test_foward_litellm_user_info_to_backend_llm_call(): litellm.add_user_information_to_llm_headers = True - from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" @@ -531,10 +543,10 @@ def test_foward_litellm_user_info_to_backend_llm_call(): def test_update_internal_user_params(): + from litellm.proxy._types import NewUserRequest from litellm.proxy.management_endpoints.internal_user_endpoints import ( _update_internal_new_user_params, ) - from litellm.proxy._types import NewUserRequest litellm.default_internal_user_params = { "max_budget": 100, @@ -558,9 +570,10 @@ def test_update_internal_user_params(): @pytest.mark.asyncio async def test_proxy_config_update_from_db(): - from litellm.proxy.proxy_server import ProxyConfig from pydantic import BaseModel + from litellm.proxy.proxy_server import ProxyConfig + proxy_config = ProxyConfig() pc = AsyncMock() @@ -602,10 +615,10 @@ async def test_proxy_config_update_from_db(): def test_prepare_key_update_data(): + from litellm.proxy._types import UpdateKeyRequest from litellm.proxy.management_endpoints.key_management_endpoints import ( prepare_key_update_data, ) - from litellm.proxy._types import UpdateKeyRequest existing_key_row = MagicMock() data = UpdateKeyRequest(key="test_key", models=["gpt-4"], duration="120s") @@ -899,9 +912,10 @@ def test_enforced_params_check( def test_get_key_models(): - from litellm.proxy.auth.model_checks import get_key_models from collections import defaultdict + from litellm.proxy.auth.model_checks import get_key_models + user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", @@ -923,9 +937,10 @@ def test_get_key_models(): def test_get_team_models(): - from litellm.proxy.auth.model_checks import get_team_models from collections import defaultdict + from litellm.proxy.auth.model_checks import get_team_models + user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", @@ -1111,9 +1126,10 @@ def test_proxy_config_state_get_config_state_error(): """ Ensures that get_config_state does not raise an error when the config is not a valid dictionary """ - from litellm.proxy.proxy_server import ProxyConfig import threading + from litellm.proxy.proxy_server import ProxyConfig + test_config = { "callback_list": [ { @@ -1248,8 +1264,8 @@ def test_is_allowed_to_make_key_request(): def test_get_model_group_info(): - from litellm.proxy.proxy_server import _get_model_group_info from litellm import Router + from litellm.proxy.proxy_server import _get_model_group_info router = Router( model_list=[ @@ -1277,10 +1293,11 @@ def test_get_model_group_info(): assert len(model_list) == 1 -import pytest import asyncio -from unittest.mock import AsyncMock, patch import json +from unittest.mock import AsyncMock, patch + +import pytest @pytest.fixture @@ -1339,7 +1356,6 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data): "litellm.proxy.proxy_server.prisma_client", MockPrismaClientDB(mock_team_data, mock_key_data), ): - from litellm.proxy.management_endpoints.internal_user_endpoints import ( _get_user_info_for_proxy_admin, ) @@ -1353,10 +1369,12 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data): def test_custom_openid_response(): - from litellm.proxy.management_endpoints.ui_sso import generic_response_convertor - from litellm.proxy.management_endpoints.ui_sso import JWTHandler - from litellm.proxy._types import LiteLLM_JWTAuth from litellm.caching import DualCache + from litellm.proxy._types import LiteLLM_JWTAuth + from litellm.proxy.management_endpoints.ui_sso import ( + JWTHandler, + generic_response_convertor, + ) jwt_handler = JWTHandler() jwt_handler.update_environment( @@ -1410,10 +1428,11 @@ def test_update_key_request_validation(): def test_get_temp_budget_increase(): - from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase - from litellm.proxy._types import UserAPIKeyAuth from datetime import datetime, timedelta + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase + expiry = datetime.now() + timedelta(days=1) expiry_in_isoformat = expiry.isoformat() @@ -1429,11 +1448,12 @@ def test_get_temp_budget_increase(): def test_update_key_budget_with_temp_budget_increase(): + from datetime import datetime, timedelta + + from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import ( _update_key_budget_with_temp_budget_increase, ) - from litellm.proxy._types import UserAPIKeyAuth - from datetime import datetime, timedelta expiry = datetime.now() + timedelta(days=1) expiry_in_isoformat = expiry.isoformat() @@ -1449,7 +1469,7 @@ def test_update_key_budget_with_temp_budget_increase(): assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200 -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import AsyncMock, MagicMock @pytest.mark.asyncio @@ -1490,17 +1510,18 @@ async def test_health_check_not_called_when_disabled(monkeypatch): }, ) def test_custom_openapi(mock_get_openapi_schema): - from litellm.proxy.proxy_server import custom_openapi - from litellm.proxy.proxy_server import app + from litellm.proxy.proxy_server import app, custom_openapi openapi_schema = custom_openapi() assert openapi_schema is not None -import pytest -from unittest.mock import MagicMock, AsyncMock import asyncio from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest + from litellm.proxy.utils import ProxyUpdateSpend @@ -1617,3 +1638,223 @@ def test_provider_specific_header(): "anthropic-beta": "prompt-caching-2024-07-31", }, } + + +async def create_budget(session, data): + url = "http://0.0.0.0:4000/budget/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + budget_id = response_data["budget_id"] + print(f"Created Budget {budget_id}") + return response_data + + +async def create_end_user(prisma_client, session, user_id, budget_id, spend=None): + url = "http://0.0.0.0:4000/end_user/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "user_id": user_id, + "budget_id": budget_id, + } + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + end_user_id = response_data["user_id"] + print(f"Created End User {end_user_id}") + + if spend is not None: + end_users = await prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[budget_id], + ) + end_user = [user for user in end_users if user.user_id == user_id][0] + end_user.spend = spend + await prisma_client.update_data( + query_type="update_many", + data_list=[end_user], + table_name="enduser", + ) + + return response_data + + +async def delete_budget(session, budget_id): + url = "http://0.0.0.0:4000/budget/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"id": budget_id} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted Budget {budget_id}") + + +async def delete_end_user(session, user_id): + url = "http://0.0.0.0:4000/end_user/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"user_ids": [user_id]} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted End User {user_id}") + + +@pytest.fixture +def prisma_client(): + proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + return prisma_client + + +@pytest_asyncio.fixture +async def budget_and_enduser_setup(prisma_client): + """ + Fixture to set up budgets and end users for testing and clean them up afterward. + + This fixture performs the following: + - Creates two budgets: + * Budget X with a short duration ("5s"). + * Budget Y with a long duration ("30d"). + - Stores the initial 'budget_reset_at' timestamps for both budgets. + - Creates three end users: + * End Users A and B are associated with Budget X and are given initial spend values. + * End User C is associated with Budget Y with an initial spend. + - After the test (after the yield), the created end users and budgets are deleted. + """ + await prisma_client.connect() + + async with aiohttp.ClientSession() as session: + # Create budgets + id_budget_x = f"budget-{uuid.uuid4()}" + data_budget_x = { + "budget_id": id_budget_x, + "budget_duration": "5s", + "max_budget": 2, + } + id_budget_y = f"budget-{uuid.uuid4()}" + data_budget_y = { + "budget_id": id_budget_y, + "budget_duration": "30d", + "max_budget": 1, + } + response_budget_x = await create_budget(session, data_budget_x) + initial_budget_x_reset_at_date = datetime.fromisoformat( + response_budget_x["budget_reset_at"] + ) + response_budget_y = await create_budget(session, data_budget_y) + initial_budget_y_reset_at_date = datetime.fromisoformat( + response_budget_y["budget_reset_at"] + ) + + # Create end users + id_end_user_a = f"test-user-{uuid.uuid4()}" + id_end_user_b = f"test-user-{uuid.uuid4()}" + id_end_user_c = f"test-user-{uuid.uuid4()}" + await create_end_user( + prisma_client, session, id_end_user_a, id_budget_x, spend=0.16 + ) + await create_end_user( + prisma_client, session, id_end_user_b, id_budget_x, spend=0.04 + ) + await create_end_user( + prisma_client, session, id_end_user_c, id_budget_y, spend=0.2 + ) + + # Bundle data needed for the test + setup_data = { + "budgets": { + "id_budget_x": id_budget_x, + "id_budget_y": id_budget_y, + "initial_budget_x_reset_at_date": initial_budget_x_reset_at_date, + "initial_budget_y_reset_at_date": initial_budget_y_reset_at_date, + }, + "end_users": { + "id_end_user_a": id_end_user_a, + "id_end_user_b": id_end_user_b, + "id_end_user_c": id_end_user_c, + }, + } + + # Provide the setup data to the test + yield setup_data + + # Clean-up: Delete the created test data + await delete_end_user(session, id_end_user_a) + await delete_end_user(session, id_end_user_b) + await delete_end_user(session, id_end_user_c) + await delete_budget(session, id_budget_x) + await delete_budget(session, id_budget_y) + + +@pytest.mark.asyncio +async def test_reset_budget_for_endusers(prisma_client, budget_and_enduser_setup): + """ + Test the part "Reset End-User (Customer) Spend and corresponding Budget duration" in reset_budget function. + + This test uses the budget_and_enduser_setup fixture to create budgets and end users, + waits for the short-duration budget to expire, calls reset_budget, and verifies that: + - End users associated with the short-duration budget X have their spend reset to 0. + - The budget_reset_at timestamp for the short-duration budget X is updated, + while the long-duration budget Y remains unchanged. + """ + + # Unpack the required data from the fixture + budgets = budget_and_enduser_setup["budgets"] + end_users = budget_and_enduser_setup["end_users"] + + id_budget_x = budgets["id_budget_x"] + id_budget_y = budgets["id_budget_y"] + initial_budget_x_reset_at_date = budgets["initial_budget_x_reset_at_date"] + initial_budget_y_reset_at_date = budgets["initial_budget_y_reset_at_date"] + + id_end_user_a = end_users["id_end_user_a"] + id_end_user_b = end_users["id_end_user_b"] + id_end_user_c = end_users["id_end_user_c"] + + # Wait for Budget X to expire (short duration "5s" plus a small buffer) + await asyncio.sleep(6) + + # Call the reset_budget function: + # It should reset the spend values for end users associated with Budget X. + await reset_budget(prisma_client) + + # Retrieve updated data for end users + updated_end_users = await prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[id_budget_x, id_budget_y], + ) + # Retrieve updated data for budgets + updated_budgets = await prisma_client.get_data( + table_name="budget", + query_type="find_all", + reset_at=datetime.now() + timedelta(days=31), + ) + + # Assertions for end users + user_a = [user for user in updated_end_users if user.user_id == id_end_user_a][0] + user_b = [user for user in updated_end_users if user.user_id == id_end_user_b][0] + user_c = [user for user in updated_end_users if user.user_id == id_end_user_c][0] + + assert user_a.spend == 0, "Spend for end_user_a was not reset to 0" + assert user_b.spend == 0, "Spend for end_user_b was not reset to 0" + assert user_c.spend > 0, "Spend for end_user_c should not be reset" + + # Assertions for budgets + budget_x = [ + budget for budget in updated_budgets if budget.budget_id == id_budget_x + ][0] + budget_y = [ + budget for budget in updated_budgets if budget.budget_id == id_budget_y + ][0] + + assert ( + budget_x.budget_reset_at > initial_budget_x_reset_at_date + ), "Budget X budget_reset_at was not updated" + assert ( + budget_y.budget_reset_at == initial_budget_y_reset_at_date + ), "Budget Y budget_reset_at should remain unchanged" diff --git a/tests/test_budget_management.py b/tests/test_budget_management.py new file mode 100644 index 0000000000..8175a9b1df --- /dev/null +++ b/tests/test_budget_management.py @@ -0,0 +1,90 @@ +# What is this? +## Unit tests for the /budget/* endpoints +import uuid +from datetime import datetime, timedelta + +import aiohttp +import pytest +import pytest_asyncio + + +async def delete_budget(session, budget_id): + url = "http://0.0.0.0:4000/budget/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"id": budget_id} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted Budget {budget_id}") + + +async def create_budget(session, data): + url = "http://0.0.0.0:4000/budget/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + budget_id = response_data["budget_id"] + print(f"Created Budget {budget_id}") + return response_data + + +@pytest_asyncio.fixture +async def budget_setup(): + """ + Fixture to create a budget for testing and clean it up afterward. + + This fixture performs the following steps: + 1. Opens an aiohttp ClientSession. + 2. Generates a random budget_id and defines the budget data (duration: 1 day, max_budget: 0.02). + 3. Calls create_budget to create the budget. + 4. Yields the budget_response (a dict) for use in the test. + 5. After the test completes, deletes the created budget by calling delete_budget. + + Returns: + dict: The JSON response from create_budget, which includes the created budget's data. + """ + + async with aiohttp.ClientSession() as session: + # Generate a unique budget_id and define the budget data. + budget_id = f"budget-{uuid.uuid4()}" + data = {"budget_id": budget_id, "budget_duration": "1d", "max_budget": 0.02} + budget_response = await create_budget(session, data) + + # Yield the response so the test can use it. + yield budget_response + + # After the test, delete the created budget to clean up. + await delete_budget(session, budget_id) + + +@pytest.mark.asyncio +async def test_create_budget_with_duration(budget_setup): + """ + Test creating a budget with a specified duration and verify that the 'budget_reset_at' + timestamp is correctly calculated as 'created_at' plus the budget duration (one day). + + This test uses the budget_setup fixture, which handles both the creation and cleanup of the budget. + """ + + # Verify that the response includes a 'budget_reset_at' timestamp. + assert ( + budget_setup["budget_reset_at"] is not None + ), "The budget_reset_at field should not be None" + + # Calculate the expected reset time: created_at + 1 day. + expected_reset_at_date = datetime.fromisoformat( + budget_setup["created_at"] + ) + timedelta(days=1) + + # Allow for a small tolerance in seconds for the timestamp calculation. + tolerance_seconds = 3 + actual_reset_at_date = datetime.fromisoformat(budget_setup["budget_reset_at"]) + time_difference = abs( + (actual_reset_at_date - expected_reset_at_date).total_seconds() + ) + + assert time_difference <= tolerance_seconds, ( + f"Expected budget_reset_at to be within {tolerance_seconds} seconds of {expected_reset_at_date}, " + f"but the difference was {time_difference} seconds." + ) From adfe462f3e316061bd7fa643d6558ee4f2592170 Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Wed, 12 Feb 2025 15:15:14 +0100 Subject: [PATCH 03/10] feat: adding fallback for already running applications with created budgets --- litellm/proxy/utils.py | 23 ++++++++++++++++++++-- tests/proxy_unit_tests/test_proxy_utils.py | 5 +++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index afd211aa95..7d203a3607 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1489,7 +1489,15 @@ class PrismaClient: if query_type == "find_all": response = await self.db.litellm_budgettable.find_many( where={ # type:ignore - "budget_reset_at": {"lt": reset_at} + "OR": [ + { + "AND": [ + {"budget_reset_at": None}, + {"NOT": {"budget_duration": None}}, + ] + }, + {"budget_reset_at": {"lt": reset_at}}, + ] } ) return response @@ -2499,8 +2507,19 @@ async def reset_budget(prisma_client: PrismaClient): budget_id_list_to_reset_enduser = [] if budgets_to_reset is not None and len(budgets_to_reset) > 0: for budget in budgets_to_reset: - budget_id_list_to_reset_enduser.append(budget.budget_id) duration_s = duration_in_seconds(duration=budget.budget_duration) + + # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account + if ( + budget.budget_reset_at is None + and budget.created_at + timedelta(seconds=duration_s) > now + ): + budget.budget_reset_at = budget.created_at + timedelta( + seconds=duration_s + ) + continue + + budget_id_list_to_reset_enduser.append(budget.budget_id) budget.budget_reset_at = now + timedelta(seconds=duration_s) await prisma_client.update_data( query_type="update_many", diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 9d756b1d66..2407d80a4f 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -3,7 +3,7 @@ import json import os import sys import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -1639,6 +1639,7 @@ def test_provider_specific_header(): }, } + async def create_budget(session, data): url = "http://0.0.0.0:4000/budget/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} @@ -1831,7 +1832,7 @@ async def test_reset_budget_for_endusers(prisma_client, budget_and_enduser_setup updated_budgets = await prisma_client.get_data( table_name="budget", query_type="find_all", - reset_at=datetime.now() + timedelta(days=31), + reset_at=datetime.now(timezone.utc) + timedelta(days=31), ) # Assertions for end users From 93ce06af216d0ca47b3a9aadab8691153e578f00 Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Wed, 19 Mar 2025 13:18:02 +0100 Subject: [PATCH 04/10] adding budget_reset_at to LiteLLM_BudgetTable --- litellm/proxy/_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d87df955fd..32364efc97 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1084,6 +1084,7 @@ class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase): rpm_limit: Optional[int] = None model_max_budget: Optional[dict] = None budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None model_config = ConfigDict(protected_namespaces=()) From 72855349034dbbf62d9c9cc7fd0520bfcf531f43 Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Wed, 19 Mar 2025 17:24:12 +0100 Subject: [PATCH 05/10] add reset enduser spend logic to new reset_budget function --- .../proxy/common_utils/reset_budget_job.py | 82 ++++++++++++- litellm/proxy/utils.py | 111 ------------------ 2 files changed, 81 insertions(+), 112 deletions(-) diff --git a/litellm/proxy/common_utils/reset_budget_job.py b/litellm/proxy/common_utils/reset_budget_job.py index 1d50002f5c..cf3c73ad9e 100644 --- a/litellm/proxy/common_utils/reset_budget_job.py +++ b/litellm/proxy/common_utils/reset_budget_job.py @@ -1,7 +1,7 @@ import asyncio import json import time -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import List, Literal, Optional, Union from litellm._logging import verbose_proxy_logger @@ -10,6 +10,7 @@ from litellm.proxy._types import ( LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken, + LiteLLM_EndUserTable ) from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm.types.services import ServiceTypes @@ -44,6 +45,85 @@ class ResetBudgetJob: ## Reset Team Budget await self.reset_budget_for_litellm_teams() + ### RESET ENDUSER (Customer) BUDGET and corresponding Budget duration ### + await self.reset_budget_for_litellm_endusers() + + + async def reset_budget_for_litellm_endusers(self): + """ + Resets the budget for all LiteLLM End-Users (Customers) if their budget has expired + The corresponding Budget duration is also updated. + """ + now = datetime.now(timezone.utc) + + start_time = time.time() + endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None + # users_to_reset: Optional[List[LiteLLM_UserTable]] = None + try: + budgets_to_reset = await self.prisma_client.get_data( + table_name="budget", query_type="find_all", reset_at=now + ) + budget_id_list_to_reset_enduser: List[str] = [] + if budgets_to_reset is not None and len(budgets_to_reset) > 0: + for budget in budgets_to_reset: + if budget.budget_duration is not None: + duration_s = duration_in_seconds(duration=budget.budget_duration) + + # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account + if ( + budget.budget_reset_at is None + and budget.created_at + timedelta(seconds=duration_s) > now + ): + budget.budget_reset_at = budget.created_at + timedelta( + seconds=duration_s + ) + continue + + budget_id_list_to_reset_enduser.append(budget.budget_id) + budget.budget_reset_at = now + timedelta(seconds=duration_s) + await self.prisma_client.update_data( + query_type="update_many", + data_list=budgets_to_reset, + table_name="budget", + ) + + endusers_to_reset = await self.prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=budget_id_list_to_reset_enduser, + ) + + if endusers_to_reset is not None and len(endusers_to_reset) > 0: + for enduser in endusers_to_reset: + enduser.spend = 0.0 + await self.prisma_client.update_data( + query_type="update_many", + data_list=endusers_to_reset, + table_name="enduser", + ) + + + except Exception as e: + end_time = time.time() + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_failure_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + error=e, + call_type="reset_budget_endusers", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_users_found": len(endusers_to_reset) if endusers_to_reset else 0, + "endusers_found": json.dumps( + endusers_to_reset, indent=4, default=str + ), + }, + ) + ) + verbose_proxy_logger.exception("Failed to reset budget for endusers: %s", e) + + async def reset_budget_for_litellm_keys(self): """ Resets the budget for all the litellm keys diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index a9977025ec..074d216988 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -2447,117 +2447,6 @@ def _hash_token_if_needed(token: str) -> str: return token -async def reset_budget(prisma_client: PrismaClient): - """ - Gets all the non-expired keys for a db, which need spend to be reset - - Resets their spend - - Updates db - """ - if prisma_client is not None: - ### RESET KEY BUDGET ### - now = datetime.utcnow() - keys_to_reset = await prisma_client.get_data( - table_name="key", query_type="find_all", expires=now, reset_at=now - ) - - if keys_to_reset is not None and len(keys_to_reset) > 0: - for key in keys_to_reset: - key.spend = 0.0 - duration_s = duration_in_seconds(duration=key.budget_duration) - key.budget_reset_at = now + timedelta(seconds=duration_s) - - await prisma_client.update_data( - query_type="update_many", data_list=keys_to_reset, table_name="key" - ) - - ### RESET USER BUDGET ### - now = datetime.utcnow() - users_to_reset = await prisma_client.get_data( - table_name="user", query_type="find_all", reset_at=now - ) - - if users_to_reset is not None and len(users_to_reset) > 0: - for user in users_to_reset: - user.spend = 0.0 - duration_s = duration_in_seconds(duration=user.budget_duration) - user.budget_reset_at = now + timedelta(seconds=duration_s) - - await prisma_client.update_data( - query_type="update_many", data_list=users_to_reset, table_name="user" - ) - - ## Reset End-User (Customer) Spend and corresponding Budget duration - now = datetime.now(timezone.utc) - - budgets_to_reset = await prisma_client.get_data( - table_name="budget", query_type="find_all", reset_at=now - ) - budget_id_list_to_reset_enduser = [] - if budgets_to_reset is not None and len(budgets_to_reset) > 0: - for budget in budgets_to_reset: - duration_s = duration_in_seconds(duration=budget.budget_duration) - - # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account - if ( - budget.budget_reset_at is None - and budget.created_at + timedelta(seconds=duration_s) > now - ): - budget.budget_reset_at = budget.created_at + timedelta( - seconds=duration_s - ) - continue - - budget_id_list_to_reset_enduser.append(budget.budget_id) - budget.budget_reset_at = now + timedelta(seconds=duration_s) - await prisma_client.update_data( - query_type="update_many", - data_list=budgets_to_reset, - table_name="budget", - ) - - endusers_to_reset = await prisma_client.get_data( - table_name="enduser", - query_type="find_all", - budget_id_list=budget_id_list_to_reset_enduser, - ) - - if endusers_to_reset is not None and len(endusers_to_reset) > 0: - for enduser in endusers_to_reset: - enduser.spend = 0.0 - await prisma_client.update_data( - query_type="update_many", - data_list=endusers_to_reset, - table_name="enduser", - ) - - ## Reset Team Budget - now = datetime.utcnow() - teams_to_reset = await prisma_client.get_data( - table_name="team", - query_type="find_all", - reset_at=now, - ) - - if teams_to_reset is not None and len(teams_to_reset) > 0: - team_reset_requests = [] - for team in teams_to_reset: - duration_s = duration_in_seconds(duration=team.budget_duration) - reset_team_budget_request = ResetTeamBudgetRequest( - team_id=team.team_id, - spend=0.0, - budget_reset_at=now + timedelta(seconds=duration_s), - updated_at=now, - ) - team_reset_requests.append(reset_team_budget_request) - await prisma_client.update_data( - query_type="update_many", - data_list=team_reset_requests, - table_name="team", - ) - - class ProxyUpdateSpend: @staticmethod async def update_end_user_spend( From 9bda42dc035d24065ee83336fc969b6a56d1b4bd Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Fri, 21 Mar 2025 09:30:08 +0100 Subject: [PATCH 06/10] adaptions to new code structure + tests --- litellm/proxy/_types.py | 4 + .../proxy/common_utils/reset_budget_job.py | 139 ++++++-- .../test_integration_reset_enduser_spend.py | 269 +++++++++++++++ .../common_utils/test_reset_budget_job.py | 73 ++++- .../test_proxy_budget_reset.py | 235 ++++++++++++- tests/proxy_unit_tests/test_proxy_utils.py | 310 ++---------------- 6 files changed, 715 insertions(+), 315 deletions(-) create mode 100644 tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6a67eb3907..9d38cf239a 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1151,6 +1151,10 @@ class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase): model_config = ConfigDict(protected_namespaces=()) +class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable): + """Represents all params for a LiteLLM_BudgetTable record""" + budget_id: str + created_at: datetime class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): """ diff --git a/litellm/proxy/common_utils/reset_budget_job.py b/litellm/proxy/common_utils/reset_budget_job.py index cf3c73ad9e..c40631f002 100644 --- a/litellm/proxy/common_utils/reset_budget_job.py +++ b/litellm/proxy/common_utils/reset_budget_job.py @@ -10,7 +10,8 @@ from litellm.proxy._types import ( LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken, - LiteLLM_EndUserTable + LiteLLM_EndUserTable, + LiteLLM_BudgetTableFull ) from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm.types.services import ServiceTypes @@ -55,54 +56,94 @@ class ResetBudgetJob: The corresponding Budget duration is also updated. """ now = datetime.now(timezone.utc) - start_time = time.time() endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None - # users_to_reset: Optional[List[LiteLLM_UserTable]] = None + budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None + updated_endusers: List[LiteLLM_EndUserTable] = [] + failed_endusers = [] try: budgets_to_reset = await self.prisma_client.get_data( table_name="budget", query_type="find_all", reset_at=now ) - budget_id_list_to_reset_enduser: List[str] = [] + if budgets_to_reset is not None and len(budgets_to_reset) > 0: for budget in budgets_to_reset: - if budget.budget_duration is not None: - duration_s = duration_in_seconds(duration=budget.budget_duration) - - # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account - if ( - budget.budget_reset_at is None - and budget.created_at + timedelta(seconds=duration_s) > now - ): - budget.budget_reset_at = budget.created_at + timedelta( - seconds=duration_s - ) - continue - - budget_id_list_to_reset_enduser.append(budget.budget_id) - budget.budget_reset_at = now + timedelta(seconds=duration_s) + budget = await ResetBudgetJob._reset_budget_reset_at_date(budget, now) await self.prisma_client.update_data( query_type="update_many", data_list=budgets_to_reset, table_name="budget", ) - endusers_to_reset = await self.prisma_client.get_data( - table_name="enduser", - query_type="find_all", - budget_id_list=budget_id_list_to_reset_enduser, - ) + endusers_to_reset = await self.prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[budget.budget_id for budget in budgets_to_reset] + ) if endusers_to_reset is not None and len(endusers_to_reset) > 0: + for enduser in endusers_to_reset: - enduser.spend = 0.0 + try: + updated_enduser = await ResetBudgetJob._reset_budget_for_enduser( + enduser=enduser + ) + if updated_enduser is not None: + updated_endusers.append(updated_enduser) + else: + failed_endusers.append( + { + "enduser": enduser, + "error": "Returned None without exception", + } + ) + except Exception as e: + failed_endusers.append({"enduser": enduser, "error": str(e)}) + verbose_proxy_logger.exception( + "Failed to reset budget for enduser: %s", enduser + ) + + verbose_proxy_logger.debug( + "Updated users %s", json.dumps(updated_endusers, indent=4, default=str) + ) + await self.prisma_client.update_data( query_type="update_many", - data_list=endusers_to_reset, + data_list=updated_endusers, table_name="enduser", ) + end_time = time.time() + if len(failed_endusers) > 0: # If any endusers failed to reset + raise Exception( + f"Failed to reset {len(failed_endusers)} endusers: {json.dumps(failed_endusers, default=str)}" + ) + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + call_type="reset_budget_endusers", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_budgets_found": len(budgets_to_reset) if budgets_to_reset else 0, + "budgets_found": json.dumps( + budgets_to_reset, indent=4, default=str + ), + "num_endusers_found": len(endusers_to_reset) if endusers_to_reset else 0, + "endusers_found": json.dumps( + endusers_to_reset, indent=4, default=str + ), + "num_endusers_updated": len(updated_endusers), + "endusers_updated": json.dumps( + updated_endusers, indent=4, default=str + ), + "num_endusers_failed": len(failed_endusers), + "endusers_failed": json.dumps(failed_endusers, indent=4, default=str), + }, + ) + ) except Exception as e: end_time = time.time() asyncio.create_task( @@ -114,10 +155,14 @@ class ResetBudgetJob: start_time=start_time, end_time=end_time, event_metadata={ - "num_users_found": len(endusers_to_reset) if endusers_to_reset else 0, + "num_budgets_found": len(budgets_to_reset) if budgets_to_reset else 0, + "budgets_found": json.dumps( + budgets_to_reset, indent=4, default=str + ), + "num_endusers_found": len(endusers_to_reset) if endusers_to_reset else 0, "endusers_found": json.dumps( endusers_to_reset, indent=4, default=str - ), + ) }, ) ) @@ -435,6 +480,44 @@ class ResetBudgetJob: ) return user + @staticmethod + async def _reset_budget_for_enduser( + enduser: LiteLLM_EndUserTable + ) -> Optional[LiteLLM_EndUserTable]: + try: + enduser.spend = 0.0 + except Exception as e: + verbose_proxy_logger.exception( + "Error resetting budget for enduser: %s. Item: %s", e, enduser + ) + raise e + return enduser + + @staticmethod + async def _reset_budget_reset_at_date( + budget: LiteLLM_BudgetTableFull, current_time: datetime + ) -> Optional[LiteLLM_BudgetTableFull]: + try: + if budget.budget_duration is not None: + duration_s = duration_in_seconds(duration=budget.budget_duration) + + # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account + if ( + budget.budget_reset_at is None + and budget.created_at + timedelta(seconds=duration_s) > current_time + ): + budget.budget_reset_at = budget.created_at + timedelta( + seconds=duration_s + ) + else: + budget.budget_reset_at = current_time + timedelta(seconds=duration_s) + except Exception as e: + verbose_proxy_logger.exception( + "Error resetting budget_reset_at for budget: %s. Item: %s", e, budget + ) + raise e + return budget + @staticmethod async def _reset_budget_for_key( key: LiteLLM_VerificationToken, current_time: datetime diff --git a/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py b/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py new file mode 100644 index 0000000000..8c242e7e5c --- /dev/null +++ b/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py @@ -0,0 +1,269 @@ +import asyncio +import os +import sys +import time +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional +from unittest.mock import AsyncMock, MagicMock, Mock, patch +from litellm.caching.caching import DualCache, RedisCache + +import aiohttp +import pytest +import pytest_asyncio +from fastapi import Request +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob +from litellm.proxy.utils import PrismaClient, ProxyLogging + + +async def create_budget(session, data): + url = "http://0.0.0.0:4000/budget/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + budget_id = response_data["budget_id"] + print(f"Created Budget {budget_id}") + return response_data + + +async def create_end_user(prisma_client, session, user_id, budget_id, spend=None): + url = "http://0.0.0.0:4000/end_user/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "user_id": user_id, + "budget_id": budget_id, + } + + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + response_data = await response.json() + end_user_id = response_data["user_id"] + print(f"Created End User {end_user_id}") + + if spend is not None: + end_users = await prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[budget_id], + ) + end_user = [user for user in end_users if user.user_id == user_id][0] + end_user.spend = spend + await prisma_client.update_data( + query_type="update_many", + data_list=[end_user], + table_name="enduser", + ) + + return response_data + + +async def delete_budget(session, budget_id): + url = "http://0.0.0.0:4000/budget/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"id": budget_id} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted Budget {budget_id}") + + +async def delete_end_user(session, user_id): + url = "http://0.0.0.0:4000/end_user/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"user_ids": [user_id]} + async with session.post(url, headers=headers, json=data) as response: + assert response.status == 200 + print(f"Deleted End User {user_id}") + + +@pytest.fixture +def prisma_client(): + proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + return prisma_client + + +class MockProxyLogging: + class MockServiceLogging: + async def async_service_success_hook(self, **kwargs): + pass + + async def async_service_failure_hook(self, **kwargs): + pass + + def __init__(self): + self.service_logging_obj = self.MockServiceLogging() + + +@pytest.fixture +def mock_proxy_logging(): + return MockProxyLogging() + + +@pytest.fixture +def reset_budget_job(prisma_client, mock_proxy_logging): + return ResetBudgetJob( + proxy_logging_obj=mock_proxy_logging, prisma_client=prisma_client + ) + + +@pytest_asyncio.fixture +async def budget_and_enduser_setup(prisma_client): + """ + Fixture to set up budgets and end users for testing and clean them up afterward. + + This fixture performs the following: + - Creates two budgets: + * Budget X with a short duration ("5s"). + * Budget Y with a long duration ("30d"). + - Stores the initial 'budget_reset_at' timestamps for both budgets. + - Creates three end users: + * End Users A and B are associated with Budget X and are given initial spend values. + * End User C is associated with Budget Y with an initial spend. + - After the test (after the yield), the created end users and budgets are deleted. + """ + await prisma_client.connect() + + async with aiohttp.ClientSession() as session: + # Create budgets + id_budget_x = f"budget-{uuid.uuid4()}" + data_budget_x = { + "budget_id": id_budget_x, + "budget_duration": "5s", + "max_budget": 2, + } + id_budget_y = f"budget-{uuid.uuid4()}" + data_budget_y = { + "budget_id": id_budget_y, + "budget_duration": "30d", + "max_budget": 1, + } + response_budget_x = await create_budget(session, data_budget_x) + initial_budget_x_reset_at_date = datetime.fromisoformat( + response_budget_x["budget_reset_at"] + ) + response_budget_y = await create_budget(session, data_budget_y) + initial_budget_y_reset_at_date = datetime.fromisoformat( + response_budget_y["budget_reset_at"] + ) + + # Create end users + id_end_user_a = f"test-user-{uuid.uuid4()}" + id_end_user_b = f"test-user-{uuid.uuid4()}" + id_end_user_c = f"test-user-{uuid.uuid4()}" + await create_end_user( + prisma_client, session, id_end_user_a, id_budget_x, spend=0.16 + ) + await create_end_user( + prisma_client, session, id_end_user_b, id_budget_x, spend=0.04 + ) + await create_end_user( + prisma_client, session, id_end_user_c, id_budget_y, spend=0.2 + ) + + # Bundle data needed for the test + setup_data = { + "budgets": { + "id_budget_x": id_budget_x, + "id_budget_y": id_budget_y, + "initial_budget_x_reset_at_date": initial_budget_x_reset_at_date, + "initial_budget_y_reset_at_date": initial_budget_y_reset_at_date, + }, + "end_users": { + "id_end_user_a": id_end_user_a, + "id_end_user_b": id_end_user_b, + "id_end_user_c": id_end_user_c, + }, + } + + # Provide the setup data to the test + yield setup_data + + # Clean-up: Delete the created test data + await delete_end_user(session, id_end_user_a) + await delete_end_user(session, id_end_user_b) + await delete_end_user(session, id_end_user_c) + await delete_budget(session, id_budget_x) + await delete_budget(session, id_budget_y) + + +@pytest.mark.asyncio +async def test_reset_budget_for_endusers( + reset_budget_job, prisma_client, budget_and_enduser_setup +): + """ + Test the part "Reset End-User (Customer) Spend and corresponding Budget duration" in reset_budget function. + + This test uses the budget_and_enduser_setup fixture to create budgets and end users, + waits for the short-duration budget to expire, calls reset_budget, and verifies that: + - End users associated with the short-duration budget X have their spend reset to 0. + - The budget_reset_at timestamp for the short-duration budget X is updated, + while the long-duration budget Y remains unchanged. + """ + + # Unpack the required data from the fixture + budgets = budget_and_enduser_setup["budgets"] + end_users = budget_and_enduser_setup["end_users"] + + id_budget_x = budgets["id_budget_x"] + id_budget_y = budgets["id_budget_y"] + initial_budget_x_reset_at_date = budgets["initial_budget_x_reset_at_date"] + initial_budget_y_reset_at_date = budgets["initial_budget_y_reset_at_date"] + + id_end_user_a = end_users["id_end_user_a"] + id_end_user_b = end_users["id_end_user_b"] + id_end_user_c = end_users["id_end_user_c"] + + # Wait for Budget X to expire (short duration "5s" plus a small buffer) + await asyncio.sleep(6) + + # Call the reset_budget function: + # It should reset the spend values for end users associated with Budget X. + await reset_budget_job.reset_budget_for_litellm_endusers() + + # Retrieve updated data for end users + updated_end_users = await prisma_client.get_data( + table_name="enduser", + query_type="find_all", + budget_id_list=[id_budget_x, id_budget_y], + ) + # Retrieve updated data for budgets + updated_budgets = await prisma_client.get_data( + table_name="budget", + query_type="find_all", + reset_at=datetime.now(timezone.utc) + timedelta(days=31), + ) + + # Assertions for end users + user_a = [user for user in updated_end_users if user.user_id == id_end_user_a][0] + user_b = [user for user in updated_end_users if user.user_id == id_end_user_b][0] + user_c = [user for user in updated_end_users if user.user_id == id_end_user_c][0] + + assert user_a.spend == 0, "Spend for end_user_a was not reset to 0" + assert user_b.spend == 0, "Spend for end_user_b was not reset to 0" + assert user_c.spend > 0, "Spend for end_user_c should not be reset" + + # Assertions for budgets + budget_x = [ + budget for budget in updated_budgets if budget.budget_id == id_budget_x + ][0] + budget_y = [ + budget for budget in updated_budgets if budget.budget_id == id_budget_y + ][0] + + assert ( + budget_x.budget_reset_at > initial_budget_x_reset_at_date + ), "Budget X budget_reset_at was not updated" + assert ( + budget_y.budget_reset_at == initial_budget_y_reset_at_date + ), "Budget Y budget_reset_at should remain unchanged" diff --git a/tests/litellm/proxy/common_utils/test_reset_budget_job.py b/tests/litellm/proxy/common_utils/test_reset_budget_job.py index bb4af00d78..d1e359554f 100644 --- a/tests/litellm/proxy/common_utils/test_reset_budget_job.py +++ b/tests/litellm/proxy/common_utils/test_reset_budget_job.py @@ -3,7 +3,7 @@ import json import os import sys import time -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest from fastapi.testclient import TestClient @@ -20,8 +20,8 @@ from litellm.proxy.utils import ProxyLogging # Mock classes for testing class MockPrismaClient: def __init__(self): - self.data = {"key": [], "user": [], "team": []} - self.updated_data = {"key": [], "user": [], "team": []} + self.data = {"key": [], "user": [], "team": [], "budget": [], "enduser": []} + self.updated_data = {"key": [], "user": [], "team": [], "budget": [], "enduser": []} async def get_data(self, table_name, query_type, **kwargs): return self.data.get(table_name, []) @@ -145,9 +145,48 @@ def test_reset_budget_for_team(reset_budget_job, mock_prisma_client): assert updated_team.budget_reset_at > now + +def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client): + # Setup test data + now = datetime.now(timezone.utc) + test_budget = type( + "LiteLLM_BudgetTable", + (), + { + "max_budget": 500.0, + "budget_duration": "1d", + "budget_reset_at": now, + "budget_id": "test-budget-1", + }, + ) + + test_enduser = type( + "LiteLLM_EndUserTable", + (), + { + "spend": 20.0, + "litellm_budget_table": test_budget, + "user_id": "test-enduser-1", + }, + ) + + mock_prisma_client.data["budget"] = [test_budget] + mock_prisma_client.data["enduser"] = [test_enduser] + + # Run the test + asyncio.run(reset_budget_job.reset_budget_for_litellm_endusers()) + + # Verify results + assert len(mock_prisma_client.updated_data["enduser"]) == 1 + assert len(mock_prisma_client.updated_data["budget"]) == 1 + updated_enduser = mock_prisma_client.updated_data["enduser"][0] + updated_budget = mock_prisma_client.updated_data["budget"][0] + assert updated_enduser.spend == 0.0 + assert updated_budget.budget_reset_at > now + def test_reset_budget_all(reset_budget_job, mock_prisma_client): # Setup test data - now = datetime.utcnow() + now = datetime.now(timezone.utc) # Create test objects for all three types test_key = type( @@ -183,9 +222,32 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client): }, ) + test_budget = type( + "LiteLLM_BudgetTable", + (), + { + "max_budget": 500.0, + "budget_duration": "1d", + "budget_reset_at": now, + "budget_id": "test-budget-1", + }, + ) + + test_enduser = type( + "LiteLLM_EndUserTable", + (), + { + "spend": 20.0, + "litellm_budget_table": test_budget, + "user_id": "test-enduser-1", + }, + ) + mock_prisma_client.data["key"] = [test_key] mock_prisma_client.data["user"] = [test_user] mock_prisma_client.data["team"] = [test_team] + mock_prisma_client.data["budget"] = [test_budget] + mock_prisma_client.data["enduser"] = [test_enduser] # Run the test asyncio.run(reset_budget_job.reset_budget()) @@ -194,8 +256,11 @@ def test_reset_budget_all(reset_budget_job, mock_prisma_client): assert len(mock_prisma_client.updated_data["key"]) == 1 assert len(mock_prisma_client.updated_data["user"]) == 1 assert len(mock_prisma_client.updated_data["team"]) == 1 + assert len(mock_prisma_client.updated_data["enduser"]) == 1 + assert len(mock_prisma_client.updated_data["budget"]) == 1 # Check that all spends were reset to 0 assert mock_prisma_client.updated_data["key"][0].spend == 0.0 assert mock_prisma_client.updated_data["user"][0].spend == 0.0 assert mock_prisma_client.updated_data["team"][0].spend == 0.0 + assert mock_prisma_client.updated_data["enduser"][0].spend == 0.0 diff --git a/tests/litellm_utils_tests/test_proxy_budget_reset.py b/tests/litellm_utils_tests/test_proxy_budget_reset.py index 1fbe493d8d..4fbde76f9a 100644 --- a/tests/litellm_utils_tests/test_proxy_budget_reset.py +++ b/tests/litellm_utils_tests/test_proxy_budget_reset.py @@ -3,7 +3,7 @@ import sys import time import traceback import uuid -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -15,6 +15,8 @@ load_dotenv() import os import tempfile from uuid import uuid4 +from litellm.proxy._types import LiteLLM_BudgetTableFull +from litellm.litellm_core_utils.duration_parser import duration_in_seconds sys.path.insert( 0, os.path.abspath("../..") @@ -180,6 +182,80 @@ async def test_reset_budget_users_partial_failure(): ) + +@pytest.mark.asyncio +async def test_reset_budget_endusers_partial_failure(): + """ + Test that if one enduser fails to reset, the reset loop still processes the other endusers. + We simulate six endsers where the first fails and the others are updated. + """ + user1 = { + "user_id": "user1", + "spend": 20.0, + "budget_id": "budget1", + } # Will trigger simulated failure + user2 = {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"} # Should be updated + user3 = {"user_id": "user3", "spend": 30.0, "budget_id": "budget1"} # Should be updated + user4 = {"user_id": "user4", "spend": 35.0, "budget_id": "budget1"} # Should be updated + user5 = {"user_id": "user5", "spend": 40.0, "budget_id": "budget1"} # Should be updated + user6 = {"user_id": "user6", "spend": 45.0, "budget_id": "budget1"} # Should be updated + + budget1 = LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) + + prisma_client = MagicMock() + + async def get_data_mock(table_name, *args, **kwargs): + if table_name == "budget": + return [budget1] + elif table_name == "enduser": + return [user1, user2, user3, user4, user5, user6] + return [] + + + prisma_client.get_data = AsyncMock() + prisma_client.get_data.side_effect = get_data_mock + + prisma_client.update_data = AsyncMock() + + proxy_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock() + proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock() + + job = ResetBudgetJob(proxy_logging_obj, prisma_client) + + async def fake_reset_enduser(enduser): + if enduser["user_id"] == "user1": + raise Exception("Simulated failure for user1") + enduser["spend"] = 0.0 + return enduser + + with patch.object( + ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser + ) as mock_reset_enduser: + await job.reset_budget_for_litellm_endusers() + await asyncio.sleep(0.1) + + assert mock_reset_enduser.call_count == 6 + assert prisma_client.update_data.await_count == 2 + update_call = prisma_client.update_data.call_args + assert update_call.kwargs.get("table_name") == "enduser" + updated_users = update_call.kwargs.get("data_list", []) + assert len(updated_users) == 5 + assert updated_users[0]["user_id"] == "user2" + assert updated_users[1]["user_id"] == "user3" + assert updated_users[2]["user_id"] == "user4" + assert updated_users[3]["user_id"] == "user5" + assert updated_users[4]["user_id"] == "user6" + + failure_hook_calls = ( + proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args_list + ) + assert any( + call.kwargs.get("call_type") == "reset_budget_endusers" + for call in failure_hook_calls + ) + @pytest.mark.asyncio async def test_reset_budget_teams_partial_failure(): """ @@ -263,6 +339,8 @@ async def test_reset_budget_continues_other_categories_on_failure(): user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Succeeds team1 = {"id": "team1", "spend": 30.0, "budget_duration": 180} team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180} + enduser1 = {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"} + budget1 = LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) prisma_client = MagicMock() @@ -273,6 +351,10 @@ async def test_reset_budget_continues_other_categories_on_failure(): return [user1, user2] elif table_name == "team": return [team1, team2] + elif table_name == "budget": + return [budget1] + elif table_name == "enduser": + return [enduser1] return [] prisma_client.get_data = AsyncMock(side_effect=fake_get_data) @@ -307,14 +389,20 @@ async def test_reset_budget_continues_other_categories_on_failure(): current_time + timedelta(seconds=team["budget_duration"]) ).isoformat() return team - + + async def fake_reset_enduser(enduser): + enduser["spend"] = 0.0 + return enduser + with patch.object( ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key ) as mock_reset_key, patch.object( ResetBudgetJob, "_reset_budget_for_user", side_effect=fake_reset_user ) as mock_reset_user, patch.object( ResetBudgetJob, "_reset_budget_for_team", side_effect=fake_reset_team - ) as mock_reset_team: + ) as mock_reset_team, patch.object( + ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser + ) as mock_reset_enduser: # Call the overall reset_budget method. await job.reset_budget() await asyncio.sleep(0.1) @@ -323,10 +411,10 @@ async def test_reset_budget_continues_other_categories_on_failure(): called_tables = { call.kwargs.get("table_name") for call in prisma_client.get_data.await_args_list } - assert called_tables == {"key", "user", "team"} + assert called_tables == {"key", "user", "team", "budget", "enduser"} - # Verify that update_data was called three times (one per category) - assert prisma_client.update_data.await_count == 3 + # Verify that update_data was called three times (one per category, enduser update includes two) + assert prisma_client.update_data.await_count == 5 calls = prisma_client.update_data.await_args_list # Check keys update: both keys succeed. @@ -346,9 +434,13 @@ async def test_reset_budget_continues_other_categories_on_failure(): assert teams_call.kwargs.get("table_name") == "team" assert len(teams_call.kwargs.get("data_list", [])) == 2 + # Check enduser update: enduser succeed. + enduser_call = calls[4] + assert enduser_call.kwargs.get("table_name") == "enduser" + assert len(enduser_call.kwargs.get("data_list", [])) == 1 # --------------------------------------------------------------------------- -# Additional tests for service logger behavior (keys, users, teams) +# Additional tests for service logger behavior (keys, users, teams, endusers) # --------------------------------------------------------------------------- @@ -685,3 +777,132 @@ async def test_service_logger_teams_failure(): teams_found_str = event_metadata.get("teams_found", "") assert "team1" in teams_found_str proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called() + + + +@pytest.mark.asyncio +async def test_service_logger_endusers_success(): + """ + Test that when resetting endusers succeeds the service logger success hook is called with + the correct metadata and no exception is logged. + """ + endusers = [ + {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, + {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"} + ] + budgets = [ + LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) + ] + + async def fake_get_data(*, table_name, query_type, **kwargs): + if table_name == "budget": + return budgets + elif table_name == "enduser": + return endusers + return [] + + prisma_client = MagicMock() + prisma_client.get_data = AsyncMock(side_effect=fake_get_data) + prisma_client.update_data = AsyncMock() + + proxy_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock() + proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock() + + job = ResetBudgetJob(proxy_logging_obj, prisma_client) + + async def fake_reset_enduser(enduser): + enduser["spend"] = 0.0 + return enduser + + with patch.object( + ResetBudgetJob, + "_reset_budget_for_enduser", + side_effect=fake_reset_enduser, + ): + with patch( + "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception" + ) as mock_verbose_exc: + await job.reset_budget_for_litellm_endusers() + await asyncio.sleep(0.1) + mock_verbose_exc.assert_not_called() + + proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() + args, kwargs = ( + proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args + ) + event_metadata = kwargs.get("event_metadata", {}) + assert event_metadata.get("num_budgets_found") == len(budgets) + assert event_metadata.get("num_endusers_found") == len(endusers) + assert event_metadata.get("num_endusers_updated") == len(endusers) + assert event_metadata.get("num_endusers_failed") == 0 + proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_not_called() + + +@pytest.mark.asyncio +async def test_service_logger_users_failure(): + """ + Test that a failure during enduser reset calls the failure hook with appropriate metadata, + logs the exception, and does not call the success hook. + """ + endusers = [ + {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, + {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"} + ] + budgets = [ + LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) + ] + + async def fake_get_data(*, table_name, query_type, **kwargs): + if table_name == "budget": + return budgets + elif table_name == "enduser": + return endusers + return [] + + prisma_client = MagicMock() + prisma_client.get_data = AsyncMock(side_effect=fake_get_data) + prisma_client.update_data = AsyncMock() + + proxy_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj = MagicMock() + proxy_logging_obj.service_logging_obj.async_service_success_hook = AsyncMock() + proxy_logging_obj.service_logging_obj.async_service_failure_hook = AsyncMock() + + job = ResetBudgetJob(proxy_logging_obj, prisma_client) + + async def fake_reset_enduser(enduser): + if enduser["user_id"] == "user1": + raise Exception("Simulated failure for user1") + enduser["spend"] = 0.0 + return enduser + + with patch.object( + ResetBudgetJob, + "_reset_budget_for_enduser", + side_effect=fake_reset_enduser, + ): + with patch( + "litellm.proxy.common_utils.reset_budget_job.verbose_proxy_logger.exception" + ) as mock_verbose_exc: + await job.reset_budget_for_litellm_endusers() + await asyncio.sleep(0.1) + # Verify exception logging + assert mock_verbose_exc.call_count >= 1 + # Verify exception was logged with correct message + assert any( + "Failed to reset budget for enduser" in str(call.args) + for call in mock_verbose_exc.call_args_list + ) + + proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() + args, kwargs = ( + proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args + ) + event_metadata = kwargs.get("event_metadata", {}) + assert event_metadata.get("num_budgets_found") == len(budgets) + assert event_metadata.get("num_endusers_found") == len(endusers) + endusers_found_str = event_metadata.get("endusers_found", "") + assert "user1" in endusers_found_str + proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called() diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 377d138339..f3f061c1e6 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1,38 +1,27 @@ import asyncio -import json import os import sys -import uuid -from datetime import datetime, timezone from typing import Any, Dict, Optional, List -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -import aiohttp +from unittest.mock import Mock +from litellm.proxy.utils import _get_redoc_url, _get_docs_url +import json import pytest -import pytest_asyncio from fastapi import Request +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path import litellm -from litellm.caching.caching import DualCache +from unittest.mock import MagicMock, patch, AsyncMock + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.auth_utils import is_request_body_safe from litellm.proxy.litellm_pre_call_utils import ( _get_dynamic_logging_metadata, add_litellm_data_to_request, ) -from litellm.proxy.utils import ( - PrismaClient, - ProxyLogging, - _get_docs_url, - _get_redoc_url, - reset_budget, -) from litellm.types.utils import SupportedCacheControls -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path - @pytest.fixture def mock_request(monkeypatch): @@ -499,9 +488,8 @@ def test_reading_openai_org_id_from_headers(): ) def test_add_litellm_data_for_backend_llm_call(headers, expected_data): import json - - from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup + from litellm.proxy._types import UserAPIKeyAuth user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" @@ -521,8 +509,8 @@ def test_foward_litellm_user_info_to_backend_llm_call(): litellm.add_user_information_to_llm_headers = True - from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup + from litellm.proxy._types import UserAPIKeyAuth user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" @@ -543,10 +531,10 @@ def test_foward_litellm_user_info_to_backend_llm_call(): def test_update_internal_user_params(): - from litellm.proxy._types import NewUserRequest from litellm.proxy.management_endpoints.internal_user_endpoints import ( _update_internal_new_user_params, ) + from litellm.proxy._types import NewUserRequest litellm.default_internal_user_params = { "max_budget": 100, @@ -570,9 +558,8 @@ def test_update_internal_user_params(): @pytest.mark.asyncio async def test_proxy_config_update_from_db(): - from pydantic import BaseModel - from litellm.proxy.proxy_server import ProxyConfig + from pydantic import BaseModel proxy_config = ProxyConfig() @@ -615,10 +602,10 @@ async def test_proxy_config_update_from_db(): def test_prepare_key_update_data(): - from litellm.proxy._types import UpdateKeyRequest from litellm.proxy.management_endpoints.key_management_endpoints import ( prepare_key_update_data, ) + from litellm.proxy._types import UpdateKeyRequest existing_key_row = MagicMock() data = UpdateKeyRequest(key="test_key", models=["gpt-4"], duration="120s") @@ -912,9 +899,8 @@ def test_enforced_params_check( def test_get_key_models(): - from collections import defaultdict - from litellm.proxy.auth.model_checks import get_key_models + from collections import defaultdict user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", @@ -937,9 +923,8 @@ def test_get_key_models(): def test_get_team_models(): - from collections import defaultdict - from litellm.proxy.auth.model_checks import get_team_models + from collections import defaultdict user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", @@ -1127,9 +1112,8 @@ def test_proxy_config_state_get_config_state_error(): """ Ensures that get_config_state does not raise an error when the config is not a valid dictionary """ - import threading - from litellm.proxy.proxy_server import ProxyConfig + import threading test_config = { "callback_list": [ @@ -1265,8 +1249,8 @@ def test_is_allowed_to_make_key_request(): def test_get_model_group_info(): - from litellm import Router from litellm.proxy.proxy_server import _get_model_group_info + from litellm import Router router = Router( model_list=[ @@ -1294,11 +1278,10 @@ def test_get_model_group_info(): assert len(model_list) == 1 -import asyncio -import json -from unittest.mock import AsyncMock, patch - import pytest +import asyncio +from unittest.mock import AsyncMock, patch +import json @pytest.fixture @@ -1357,6 +1340,7 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data): "litellm.proxy.proxy_server.prisma_client", MockPrismaClientDB(mock_team_data, mock_key_data), ): + from litellm.proxy.management_endpoints.internal_user_endpoints import ( _get_user_info_for_proxy_admin, ) @@ -1370,12 +1354,10 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data): def test_custom_openid_response(): - from litellm.caching import DualCache + from litellm.proxy.management_endpoints.ui_sso import generic_response_convertor + from litellm.proxy.management_endpoints.ui_sso import JWTHandler from litellm.proxy._types import LiteLLM_JWTAuth - from litellm.proxy.management_endpoints.ui_sso import ( - JWTHandler, - generic_response_convertor, - ) + from litellm.caching import DualCache jwt_handler = JWTHandler() jwt_handler.update_environment( @@ -1429,10 +1411,9 @@ def test_update_key_request_validation(): def test_get_temp_budget_increase(): - from datetime import datetime, timedelta - - from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase + from litellm.proxy._types import UserAPIKeyAuth + from datetime import datetime, timedelta expiry = datetime.now() + timedelta(days=1) expiry_in_isoformat = expiry.isoformat() @@ -1449,12 +1430,11 @@ def test_get_temp_budget_increase(): def test_update_key_budget_with_temp_budget_increase(): - from datetime import datetime, timedelta - - from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import ( _update_key_budget_with_temp_budget_increase, ) + from litellm.proxy._types import UserAPIKeyAuth + from datetime import datetime, timedelta expiry = datetime.now() + timedelta(days=1) expiry_in_isoformat = expiry.isoformat() @@ -1470,7 +1450,7 @@ def test_update_key_budget_with_temp_budget_increase(): assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200 -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock, AsyncMock @pytest.mark.asyncio @@ -1511,18 +1491,17 @@ async def test_health_check_not_called_when_disabled(monkeypatch): }, ) def test_custom_openapi(mock_get_openapi_schema): - from litellm.proxy.proxy_server import app, custom_openapi + from litellm.proxy.proxy_server import custom_openapi + from litellm.proxy.proxy_server import app openapi_schema = custom_openapi() assert openapi_schema is not None +import pytest +from unittest.mock import MagicMock, AsyncMock import asyncio from datetime import timedelta -from unittest.mock import AsyncMock, MagicMock - -import pytest - from litellm.proxy.utils import ProxyUpdateSpend @@ -1644,227 +1623,6 @@ def test_provider_specific_header(): from litellm.proxy._types import LiteLLM_UserTable - -async def create_budget(session, data): - url = "http://0.0.0.0:4000/budget/new" - headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} - - async with session.post(url, headers=headers, json=data) as response: - assert response.status == 200 - response_data = await response.json() - budget_id = response_data["budget_id"] - print(f"Created Budget {budget_id}") - return response_data - - -async def create_end_user(prisma_client, session, user_id, budget_id, spend=None): - url = "http://0.0.0.0:4000/end_user/new" - headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} - data = { - "user_id": user_id, - "budget_id": budget_id, - } - - async with session.post(url, headers=headers, json=data) as response: - assert response.status == 200 - response_data = await response.json() - end_user_id = response_data["user_id"] - print(f"Created End User {end_user_id}") - - if spend is not None: - end_users = await prisma_client.get_data( - table_name="enduser", - query_type="find_all", - budget_id_list=[budget_id], - ) - end_user = [user for user in end_users if user.user_id == user_id][0] - end_user.spend = spend - await prisma_client.update_data( - query_type="update_many", - data_list=[end_user], - table_name="enduser", - ) - - return response_data - - -async def delete_budget(session, budget_id): - url = "http://0.0.0.0:4000/budget/delete" - headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} - data = {"id": budget_id} - async with session.post(url, headers=headers, json=data) as response: - assert response.status == 200 - print(f"Deleted Budget {budget_id}") - - -async def delete_end_user(session, user_id): - url = "http://0.0.0.0:4000/end_user/delete" - headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} - data = {"user_ids": [user_id]} - async with session.post(url, headers=headers, json=data) as response: - assert response.status == 200 - print(f"Deleted End User {user_id}") - - -@pytest.fixture -def prisma_client(): - proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) - prisma_client = PrismaClient( - database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj - ) - return prisma_client - - -@pytest_asyncio.fixture -async def budget_and_enduser_setup(prisma_client): - """ - Fixture to set up budgets and end users for testing and clean them up afterward. - - This fixture performs the following: - - Creates two budgets: - * Budget X with a short duration ("5s"). - * Budget Y with a long duration ("30d"). - - Stores the initial 'budget_reset_at' timestamps for both budgets. - - Creates three end users: - * End Users A and B are associated with Budget X and are given initial spend values. - * End User C is associated with Budget Y with an initial spend. - - After the test (after the yield), the created end users and budgets are deleted. - """ - await prisma_client.connect() - - async with aiohttp.ClientSession() as session: - # Create budgets - id_budget_x = f"budget-{uuid.uuid4()}" - data_budget_x = { - "budget_id": id_budget_x, - "budget_duration": "5s", - "max_budget": 2, - } - id_budget_y = f"budget-{uuid.uuid4()}" - data_budget_y = { - "budget_id": id_budget_y, - "budget_duration": "30d", - "max_budget": 1, - } - response_budget_x = await create_budget(session, data_budget_x) - initial_budget_x_reset_at_date = datetime.fromisoformat( - response_budget_x["budget_reset_at"] - ) - response_budget_y = await create_budget(session, data_budget_y) - initial_budget_y_reset_at_date = datetime.fromisoformat( - response_budget_y["budget_reset_at"] - ) - - # Create end users - id_end_user_a = f"test-user-{uuid.uuid4()}" - id_end_user_b = f"test-user-{uuid.uuid4()}" - id_end_user_c = f"test-user-{uuid.uuid4()}" - await create_end_user( - prisma_client, session, id_end_user_a, id_budget_x, spend=0.16 - ) - await create_end_user( - prisma_client, session, id_end_user_b, id_budget_x, spend=0.04 - ) - await create_end_user( - prisma_client, session, id_end_user_c, id_budget_y, spend=0.2 - ) - - # Bundle data needed for the test - setup_data = { - "budgets": { - "id_budget_x": id_budget_x, - "id_budget_y": id_budget_y, - "initial_budget_x_reset_at_date": initial_budget_x_reset_at_date, - "initial_budget_y_reset_at_date": initial_budget_y_reset_at_date, - }, - "end_users": { - "id_end_user_a": id_end_user_a, - "id_end_user_b": id_end_user_b, - "id_end_user_c": id_end_user_c, - }, - } - - # Provide the setup data to the test - yield setup_data - - # Clean-up: Delete the created test data - await delete_end_user(session, id_end_user_a) - await delete_end_user(session, id_end_user_b) - await delete_end_user(session, id_end_user_c) - await delete_budget(session, id_budget_x) - await delete_budget(session, id_budget_y) - - -@pytest.mark.asyncio -async def test_reset_budget_for_endusers(prisma_client, budget_and_enduser_setup): - """ - Test the part "Reset End-User (Customer) Spend and corresponding Budget duration" in reset_budget function. - - This test uses the budget_and_enduser_setup fixture to create budgets and end users, - waits for the short-duration budget to expire, calls reset_budget, and verifies that: - - End users associated with the short-duration budget X have their spend reset to 0. - - The budget_reset_at timestamp for the short-duration budget X is updated, - while the long-duration budget Y remains unchanged. - """ - - # Unpack the required data from the fixture - budgets = budget_and_enduser_setup["budgets"] - end_users = budget_and_enduser_setup["end_users"] - - id_budget_x = budgets["id_budget_x"] - id_budget_y = budgets["id_budget_y"] - initial_budget_x_reset_at_date = budgets["initial_budget_x_reset_at_date"] - initial_budget_y_reset_at_date = budgets["initial_budget_y_reset_at_date"] - - id_end_user_a = end_users["id_end_user_a"] - id_end_user_b = end_users["id_end_user_b"] - id_end_user_c = end_users["id_end_user_c"] - - # Wait for Budget X to expire (short duration "5s" plus a small buffer) - await asyncio.sleep(6) - - # Call the reset_budget function: - # It should reset the spend values for end users associated with Budget X. - await reset_budget(prisma_client) - - # Retrieve updated data for end users - updated_end_users = await prisma_client.get_data( - table_name="enduser", - query_type="find_all", - budget_id_list=[id_budget_x, id_budget_y], - ) - # Retrieve updated data for budgets - updated_budgets = await prisma_client.get_data( - table_name="budget", - query_type="find_all", - reset_at=datetime.now(timezone.utc) + timedelta(days=31), - ) - - # Assertions for end users - user_a = [user for user in updated_end_users if user.user_id == id_end_user_a][0] - user_b = [user for user in updated_end_users if user.user_id == id_end_user_b][0] - user_c = [user for user in updated_end_users if user.user_id == id_end_user_c][0] - - assert user_a.spend == 0, "Spend for end_user_a was not reset to 0" - assert user_b.spend == 0, "Spend for end_user_b was not reset to 0" - assert user_c.spend > 0, "Spend for end_user_c should not be reset" - - # Assertions for budgets - budget_x = [ - budget for budget in updated_budgets if budget.budget_id == id_budget_x - ][0] - budget_y = [ - budget for budget in updated_budgets if budget.budget_id == id_budget_y - ][0] - - assert ( - budget_x.budget_reset_at > initial_budget_x_reset_at_date - ), "Budget X budget_reset_at was not updated" - assert ( - budget_y.budget_reset_at == initial_budget_y_reset_at_date - ), "Budget Y budget_reset_at should remain unchanged" - - @pytest.mark.parametrize( "wildcard_model, expected_models", [ @@ -2061,4 +1819,4 @@ async def test_get_admin_team_ids( where={"team_id": {"in": user_info.teams}} ) else: - mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called() + mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called() \ No newline at end of file From d260037de4e9570195c427e8dbdf84190f956995 Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Fri, 21 Mar 2025 09:37:45 +0100 Subject: [PATCH 07/10] formatting --- litellm/proxy/_types.py | 69 ++++---- .../proxy/common_utils/reset_budget_job.py | 54 +++--- .../common_utils/test_reset_budget_job.py | 15 +- .../test_proxy_budget_reset.py | 161 +++++++++++------- 4 files changed, 181 insertions(+), 118 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 9d38cf239a..6d817662ab 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -596,9 +596,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[dict] = ( - {} - ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[ + dict + ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -858,12 +858,12 @@ class NewCustomerRequest(BudgetNewRequest): alias: Optional[str] = None # human-friendly alias blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model @model_validator(mode="before") @classmethod @@ -885,12 +885,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1025,9 +1025,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( - "success_and_failure" - ) + callback_type: Optional[ + Literal["success", "failure", "success_and_failure"] + ] = "success_and_failure" callback_vars: Dict[str, str] @model_validator(mode="before") @@ -1151,11 +1151,14 @@ class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase): model_config = ConfigDict(protected_namespaces=()) + class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable): """Represents all params for a LiteLLM_BudgetTable record""" + budget_id: str created_at: datetime + class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): """ Used to track spend of a user_id within a team_id @@ -1289,9 +1292,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[List[FieldDetail]] = ( - None # For nested dictionary or Pydantic fields - ) + nested_fields: Optional[ + List[FieldDetail] + ] = None # For nested dictionary or Pydantic fields class ConfigGeneralSettings(LiteLLMPydanticObjectBase): @@ -1557,9 +1560,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[Any] = ( - None # You might want to replace 'Any' with a more specific type if available - ) + user: Optional[ + Any + ] = None # You might want to replace 'Any' with a more specific type if available litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -2240,9 +2243,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[float] = ( - None # Users max budget within the organization - ) + max_budget_in_organization: Optional[ + float + ] = None # Users max budget within the organization class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -2431,9 +2434,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[str, ProviderBudgetResponseObject] = ( - {} - ) # Dictionary mapping provider names to their budget configurations + providers: Dict[ + str, ProviderBudgetResponseObject + ] = {} # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -2561,9 +2564,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[str] = ( - None # can be either user / team, inferred from the role mapping - ) + object_id_jwt_field: Optional[ + str + ] = None # can be either user / team, inferred from the role mapping scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/common_utils/reset_budget_job.py b/litellm/proxy/common_utils/reset_budget_job.py index c40631f002..0b82eecd9d 100644 --- a/litellm/proxy/common_utils/reset_budget_job.py +++ b/litellm/proxy/common_utils/reset_budget_job.py @@ -7,11 +7,11 @@ from typing import List, Literal, Optional, Union from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import ( + LiteLLM_BudgetTableFull, + LiteLLM_EndUserTable, LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken, - LiteLLM_EndUserTable, - LiteLLM_BudgetTableFull ) from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm.types.services import ServiceTypes @@ -49,7 +49,6 @@ class ResetBudgetJob: ### RESET ENDUSER (Customer) BUDGET and corresponding Budget duration ### await self.reset_budget_for_litellm_endusers() - async def reset_budget_for_litellm_endusers(self): """ Resets the budget for all LiteLLM End-Users (Customers) if their budget has expired @@ -58,7 +57,7 @@ class ResetBudgetJob: now = datetime.now(timezone.utc) start_time = time.time() endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None - budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None + budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None updated_endusers: List[LiteLLM_EndUserTable] = [] failed_endusers = [] try: @@ -68,7 +67,9 @@ class ResetBudgetJob: if budgets_to_reset is not None and len(budgets_to_reset) > 0: for budget in budgets_to_reset: - budget = await ResetBudgetJob._reset_budget_reset_at_date(budget, now) + budget = await ResetBudgetJob._reset_budget_reset_at_date( + budget, now + ) await self.prisma_client.update_data( query_type="update_many", data_list=budgets_to_reset, @@ -78,15 +79,16 @@ class ResetBudgetJob: endusers_to_reset = await self.prisma_client.get_data( table_name="enduser", query_type="find_all", - budget_id_list=[budget.budget_id for budget in budgets_to_reset] + budget_id_list=[budget.budget_id for budget in budgets_to_reset], ) if endusers_to_reset is not None and len(endusers_to_reset) > 0: - for enduser in endusers_to_reset: try: - updated_enduser = await ResetBudgetJob._reset_budget_for_enduser( - enduser=enduser + updated_enduser = ( + await ResetBudgetJob._reset_budget_for_enduser( + enduser=enduser + ) ) if updated_enduser is not None: updated_endusers.append(updated_enduser) @@ -104,7 +106,8 @@ class ResetBudgetJob: ) verbose_proxy_logger.debug( - "Updated users %s", json.dumps(updated_endusers, indent=4, default=str) + "Updated users %s", + json.dumps(updated_endusers, indent=4, default=str), ) await self.prisma_client.update_data( @@ -127,11 +130,15 @@ class ResetBudgetJob: start_time=start_time, end_time=end_time, event_metadata={ - "num_budgets_found": len(budgets_to_reset) if budgets_to_reset else 0, + "num_budgets_found": len(budgets_to_reset) + if budgets_to_reset + else 0, "budgets_found": json.dumps( budgets_to_reset, indent=4, default=str ), - "num_endusers_found": len(endusers_to_reset) if endusers_to_reset else 0, + "num_endusers_found": len(endusers_to_reset) + if endusers_to_reset + else 0, "endusers_found": json.dumps( endusers_to_reset, indent=4, default=str ), @@ -140,7 +147,9 @@ class ResetBudgetJob: updated_endusers, indent=4, default=str ), "num_endusers_failed": len(failed_endusers), - "endusers_failed": json.dumps(failed_endusers, indent=4, default=str), + "endusers_failed": json.dumps( + failed_endusers, indent=4, default=str + ), }, ) ) @@ -155,20 +164,23 @@ class ResetBudgetJob: start_time=start_time, end_time=end_time, event_metadata={ - "num_budgets_found": len(budgets_to_reset) if budgets_to_reset else 0, + "num_budgets_found": len(budgets_to_reset) + if budgets_to_reset + else 0, "budgets_found": json.dumps( budgets_to_reset, indent=4, default=str ), - "num_endusers_found": len(endusers_to_reset) if endusers_to_reset else 0, + "num_endusers_found": len(endusers_to_reset) + if endusers_to_reset + else 0, "endusers_found": json.dumps( endusers_to_reset, indent=4, default=str - ) + ), }, ) ) verbose_proxy_logger.exception("Failed to reset budget for endusers: %s", e) - async def reset_budget_for_litellm_keys(self): """ Resets the budget for all the litellm keys @@ -482,7 +494,7 @@ class ResetBudgetJob: @staticmethod async def _reset_budget_for_enduser( - enduser: LiteLLM_EndUserTable + enduser: LiteLLM_EndUserTable, ) -> Optional[LiteLLM_EndUserTable]: try: enduser.spend = 0.0 @@ -492,7 +504,7 @@ class ResetBudgetJob: ) raise e return enduser - + @staticmethod async def _reset_budget_reset_at_date( budget: LiteLLM_BudgetTableFull, current_time: datetime @@ -510,7 +522,9 @@ class ResetBudgetJob: seconds=duration_s ) else: - budget.budget_reset_at = current_time + timedelta(seconds=duration_s) + budget.budget_reset_at = current_time + timedelta( + seconds=duration_s + ) except Exception as e: verbose_proxy_logger.exception( "Error resetting budget_reset_at for budget: %s. Item: %s", e, budget diff --git a/tests/litellm/proxy/common_utils/test_reset_budget_job.py b/tests/litellm/proxy/common_utils/test_reset_budget_job.py index d1e359554f..886ca34b71 100644 --- a/tests/litellm/proxy/common_utils/test_reset_budget_job.py +++ b/tests/litellm/proxy/common_utils/test_reset_budget_job.py @@ -1,12 +1,9 @@ import asyncio -import json import os import sys -import time -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone import pytest -from fastapi.testclient import TestClient sys.path.insert( 0, os.path.abspath("../../..") @@ -21,7 +18,13 @@ from litellm.proxy.utils import ProxyLogging class MockPrismaClient: def __init__(self): self.data = {"key": [], "user": [], "team": [], "budget": [], "enduser": []} - self.updated_data = {"key": [], "user": [], "team": [], "budget": [], "enduser": []} + self.updated_data = { + "key": [], + "user": [], + "team": [], + "budget": [], + "enduser": [], + } async def get_data(self, table_name, query_type, **kwargs): return self.data.get(table_name, []) @@ -145,7 +148,6 @@ def test_reset_budget_for_team(reset_budget_job, mock_prisma_client): assert updated_team.budget_reset_at > now - def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client): # Setup test data now = datetime.now(timezone.utc) @@ -184,6 +186,7 @@ def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client): assert updated_enduser.spend == 0.0 assert updated_budget.budget_reset_at > now + def test_reset_budget_all(reset_budget_job, mock_prisma_client): # Setup test data now = datetime.now(timezone.utc) diff --git a/tests/litellm_utils_tests/test_proxy_budget_reset.py b/tests/litellm_utils_tests/test_proxy_budget_reset.py index 4fbde76f9a..a3c9523941 100644 --- a/tests/litellm_utils_tests/test_proxy_budget_reset.py +++ b/tests/litellm_utils_tests/test_proxy_budget_reset.py @@ -1,34 +1,22 @@ +import asyncio import os import sys -import time -import traceback -import uuid from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from dotenv import load_dotenv -import json -import asyncio load_dotenv() import os -import tempfile -from uuid import uuid4 + from litellm.proxy._types import LiteLLM_BudgetTableFull -from litellm.litellm_core_utils.duration_parser import duration_in_seconds sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob -from litellm.proxy._types import ( - LiteLLM_VerificationToken, - LiteLLM_UserTable, - LiteLLM_TeamTable, -) -from litellm.types.services import ServiceTypes # Note: In our "fake" items we use dicts with fields that our fake reset functions modify. # In a real-world scenario, these would be instances of LiteLLM_VerificationToken, LiteLLM_UserTable, etc. @@ -182,7 +170,6 @@ async def test_reset_budget_users_partial_failure(): ) - @pytest.mark.asyncio async def test_reset_budget_endusers_partial_failure(): """ @@ -194,13 +181,40 @@ async def test_reset_budget_endusers_partial_failure(): "spend": 20.0, "budget_id": "budget1", } # Will trigger simulated failure - user2 = {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"} # Should be updated - user3 = {"user_id": "user3", "spend": 30.0, "budget_id": "budget1"} # Should be updated - user4 = {"user_id": "user4", "spend": 35.0, "budget_id": "budget1"} # Should be updated - user5 = {"user_id": "user5", "spend": 40.0, "budget_id": "budget1"} # Should be updated - user6 = {"user_id": "user6", "spend": 45.0, "budget_id": "budget1"} # Should be updated - - budget1 = LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) + user2 = { + "user_id": "user2", + "spend": 25.0, + "budget_id": "budget1", + } # Should be updated + user3 = { + "user_id": "user3", + "spend": 30.0, + "budget_id": "budget1", + } # Should be updated + user4 = { + "user_id": "user4", + "spend": 35.0, + "budget_id": "budget1", + } # Should be updated + user5 = { + "user_id": "user5", + "spend": 40.0, + "budget_id": "budget1", + } # Should be updated + user6 = { + "user_id": "user6", + "spend": 45.0, + "budget_id": "budget1", + } # Should be updated + + budget1 = LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) prisma_client = MagicMock() @@ -211,7 +225,6 @@ async def test_reset_budget_endusers_partial_failure(): return [user1, user2, user3, user4, user5, user6] return [] - prisma_client.get_data = AsyncMock() prisma_client.get_data.side_effect = get_data_mock @@ -229,7 +242,7 @@ async def test_reset_budget_endusers_partial_failure(): raise Exception("Simulated failure for user1") enduser["spend"] = 0.0 return enduser - + with patch.object( ResetBudgetJob, "_reset_budget_for_enduser", side_effect=fake_reset_enduser ) as mock_reset_enduser: @@ -256,6 +269,7 @@ async def test_reset_budget_endusers_partial_failure(): for call in failure_hook_calls ) + @pytest.mark.asyncio async def test_reset_budget_teams_partial_failure(): """ @@ -339,8 +353,15 @@ async def test_reset_budget_continues_other_categories_on_failure(): user2 = {"id": "user2", "spend": 25.0, "budget_duration": 120} # Succeeds team1 = {"id": "team1", "spend": 30.0, "budget_duration": 180} team2 = {"id": "team2", "spend": 35.0, "budget_duration": 180} - enduser1 = {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"} - budget1 = LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) + enduser1 = {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"} + budget1 = LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) prisma_client = MagicMock() @@ -389,11 +410,11 @@ async def test_reset_budget_continues_other_categories_on_failure(): current_time + timedelta(seconds=team["budget_duration"]) ).isoformat() return team - + async def fake_reset_enduser(enduser): enduser["spend"] = 0.0 return enduser - + with patch.object( ResetBudgetJob, "_reset_budget_for_key", side_effect=fake_reset_key ) as mock_reset_key, patch.object( @@ -439,6 +460,7 @@ async def test_reset_budget_continues_other_categories_on_failure(): assert enduser_call.kwargs.get("table_name") == "enduser" assert len(enduser_call.kwargs.get("data_list", [])) == 1 + # --------------------------------------------------------------------------- # Additional tests for service logger behavior (keys, users, teams, endusers) # --------------------------------------------------------------------------- @@ -487,9 +509,10 @@ async def test_service_logger_keys_success(): # Verify success hook call proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_keys_found") == len(keys) assert event_metadata.get("num_keys_updated") == len(keys) @@ -548,9 +571,10 @@ async def test_service_logger_keys_failure(): ) proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_keys_found") == len(keys) keys_found_str = event_metadata.get("keys_found", "") @@ -600,9 +624,10 @@ async def test_service_logger_users_success(): mock_verbose_exc.assert_not_called() proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_users_found") == len(users) assert event_metadata.get("num_users_updated") == len(users) @@ -659,9 +684,10 @@ async def test_service_logger_users_failure(): ) proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_users_found") == len(users) users_found_str = event_metadata.get("users_found", "") @@ -710,9 +736,10 @@ async def test_service_logger_teams_success(): mock_verbose_exc.assert_not_called() proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_teams_found") == len(teams) assert event_metadata.get("num_teams_updated") == len(teams) @@ -769,9 +796,10 @@ async def test_service_logger_teams_failure(): ) proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_teams_found") == len(teams) teams_found_str = event_metadata.get("teams_found", "") @@ -779,7 +807,6 @@ async def test_service_logger_teams_failure(): proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_not_called() - @pytest.mark.asyncio async def test_service_logger_endusers_success(): """ @@ -787,11 +814,18 @@ async def test_service_logger_endusers_success(): the correct metadata and no exception is logged. """ endusers = [ - {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, - {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"} + {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, + {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"}, ] budgets = [ - LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) + LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) ] async def fake_get_data(*, table_name, query_type, **kwargs): @@ -829,9 +863,10 @@ async def test_service_logger_endusers_success(): mock_verbose_exc.assert_not_called() proxy_logging_obj.service_logging_obj.async_service_success_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_success_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_budgets_found") == len(budgets) assert event_metadata.get("num_endusers_found") == len(endusers) @@ -847,11 +882,18 @@ async def test_service_logger_users_failure(): logs the exception, and does not call the success hook. """ endusers = [ - {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, - {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"} + {"user_id": "user1", "spend": 25.0, "budget_id": "budget1"}, + {"user_id": "user2", "spend": 25.0, "budget_id": "budget1"}, ] budgets = [ - LiteLLM_BudgetTableFull(**{"budget_id": "budget1", "max_budget": 65.0, "budget_duration": "2d", "created_at": datetime.now(timezone.utc) - timedelta(days=3)}) + LiteLLM_BudgetTableFull( + **{ + "budget_id": "budget1", + "max_budget": 65.0, + "budget_duration": "2d", + "created_at": datetime.now(timezone.utc) - timedelta(days=3), + } + ) ] async def fake_get_data(*, table_name, query_type, **kwargs): @@ -897,9 +939,10 @@ async def test_service_logger_users_failure(): ) proxy_logging_obj.service_logging_obj.async_service_failure_hook.assert_called_once() - args, kwargs = ( - proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args - ) + ( + args, + kwargs, + ) = proxy_logging_obj.service_logging_obj.async_service_failure_hook.call_args event_metadata = kwargs.get("event_metadata", {}) assert event_metadata.get("num_budgets_found") == len(budgets) assert event_metadata.get("num_endusers_found") == len(endusers) From 9cd347c1f2c6557ec78fb6ca45e87e9163d3e9ab Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Wed, 23 Apr 2025 14:31:28 +0200 Subject: [PATCH 08/10] remove unused imports --- litellm/proxy/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index ea590f58d8..7fb4ed9ebe 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -7,7 +7,7 @@ import smtplib import threading import time import traceback -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import ( @@ -23,7 +23,6 @@ from typing import ( ) from litellm.constants import MAX_TEAM_LIST_LIMIT -from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, CommonProxyErrors, From 16ac6306a9e3868355eb3600ecedb87dfc73a20b Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Wed, 23 Apr 2025 15:05:46 +0200 Subject: [PATCH 09/10] load env in test --- .../test_integration_reset_enduser_spend.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py b/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py index 8c242e7e5c..999584451b 100644 --- a/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py +++ b/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py @@ -1,24 +1,22 @@ import asyncio import os import sys -import time import uuid from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional -from unittest.mock import AsyncMock, MagicMock, Mock, patch -from litellm.caching.caching import DualCache, RedisCache import aiohttp import pytest import pytest_asyncio -from fastapi import Request -from fastapi.testclient import TestClient +from dotenv import load_dotenv + +from litellm.caching.caching import DualCache + +load_dotenv() sys.path.insert( 0, os.path.abspath("../../..") ) # Adds the parent directory to the system path -from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob from litellm.proxy.utils import PrismaClient, ProxyLogging From 587723996f0af2836a52a56cd899317bc80dbf5a Mon Sep 17 00:00:00 2001 From: Laurien Lummer Date: Wed, 23 Apr 2025 20:44:40 +0200 Subject: [PATCH 10/10] move file to local testing --- .../test_enduser_spend_reset.py} | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) rename tests/{litellm/proxy/common_utils/test_integration_reset_enduser_spend.py => local_testing/test_enduser_spend_reset.py} (99%) diff --git a/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py b/tests/local_testing/test_enduser_spend_reset.py similarity index 99% rename from tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py rename to tests/local_testing/test_enduser_spend_reset.py index 999584451b..832ba9b324 100644 --- a/tests/litellm/proxy/common_utils/test_integration_reset_enduser_spend.py +++ b/tests/local_testing/test_enduser_spend_reset.py @@ -10,16 +10,15 @@ import pytest_asyncio from dotenv import load_dotenv from litellm.caching.caching import DualCache - -load_dotenv() - -sys.path.insert( - 0, os.path.abspath("../../..") -) # Adds the parent directory to the system path - from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob from litellm.proxy.utils import PrismaClient, ProxyLogging +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +load_dotenv() + async def create_budget(session, data): url = "http://0.0.0.0:4000/budget/new"