diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index ede3624da6..9d5a12a036 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -1007,7 +1007,29 @@ async def get_key_object(
             code=status.HTTP_401_UNAUTHORIZED,
         )
 
-    _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
+    token_data = _valid_token.model_dump(exclude_none=True)
+
+    # Manually map the budget from the joined budget table if it exists
+    # This ensures the max_budget from LiteLLM_BudgetTable takes precedence if set
+    if token_data.get("litellm_budget_table_max_budget") is not None:
+        token_data["max_budget"] = token_data.pop("litellm_budget_table_max_budget")
+    if token_data.get("litellm_budget_table_soft_budget") is not None:
+        token_data["soft_budget"] = token_data.pop("litellm_budget_table_soft_budget")
+    # Only override if budget table value is explicitly set (not None) and > 0
+    budget_tpm_limit = token_data.pop("litellm_budget_table_tpm_limit", None)
+    if budget_tpm_limit is not None and budget_tpm_limit > 0:
+        token_data["tpm_limit"] = budget_tpm_limit
+    budget_rpm_limit = token_data.pop("litellm_budget_table_rpm_limit", None)
+    if budget_rpm_limit is not None and budget_rpm_limit > 0:
+        token_data["rpm_limit"] = budget_rpm_limit
+    # Also map per-model budgets from the budget table; without this the
+    # joined LiteLLM_BudgetTable.model_max_budget is silently dropped
+    # (the accompanying test asserts it is loaded onto the key object).
+    budget_model_max_budget = token_data.pop("litellm_budget_table_model_max_budget", None)
+    if budget_model_max_budget:
+        token_data["model_max_budget"] = budget_model_max_budget
+
+
+    _response = UserAPIKeyAuth(**token_data)
 
     # save the key object to cache
     await _cache_key_object(
diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py
index 4b59f603d3..7fe6be98ef 100644
--- a/litellm/proxy/hooks/max_budget_limiter.py
+++ b/litellm/proxy/hooks/max_budget_limiter.py
@@ -21,29 +21,37 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
     ):
         try:
             verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
-            cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
-            user_row = await cache.async_get_cache(
-                cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
-            )
-            if user_row is None:  # value not yet
cached - return - max_budget = user_row["max_budget"] - curr_spend = user_row["spend"] + + # Use the budget information directly from the validated user_api_key_dict + max_budget = user_api_key_dict.max_budget + curr_spend = user_api_key_dict.spend if max_budget is None: + # No budget limit set for this key/user/team return if curr_spend is None: - return + # If spend tracking hasn't started, assume 0 + curr_spend = 0.0 # CHECK IF REQUEST ALLOWED if curr_spend >= max_budget: + verbose_proxy_logger.info( + f"Budget Limit Reached for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}" + ) raise HTTPException(status_code=429, detail="Max budget limit reached.") + else: + verbose_proxy_logger.debug( + f"Budget Check Passed for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}" + ) + except HTTPException as e: + # Re-raise HTTPException to ensure FastAPI handles it correctly raise e except Exception as e: verbose_logger.exception( - "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format( + "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format( str(e) ) ) + pass diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 242c013d67..2cc553d4ec 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -195,10 +195,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "global_max_parallel_requests", None ) tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) - if tpm_limit is None: + if tpm_limit is None or tpm_limit == 0: # Treat 0 as no limit tpm_limit = sys.maxsize rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize) - if rpm_limit is None: + if rpm_limit is 
None or rpm_limit == 0: # Treat 0 as no limit rpm_limit = sys.maxsize values_to_update_in_cache: List[ @@ -310,9 +310,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): if _model is not None: if _tpm_limit_for_key_model: tpm_limit_for_model = _tpm_limit_for_key_model.get(_model) + if tpm_limit_for_model == 0: # Treat 0 as no limit + tpm_limit_for_model = sys.maxsize if _rpm_limit_for_key_model: rpm_limit_for_model = _rpm_limit_for_key_model.get(_model) + if rpm_limit_for_model == 0: # Treat 0 as no limit + rpm_limit_for_model = sys.maxsize new_val = await self.check_key_in_limits( user_api_key_dict=user_api_key_dict, @@ -350,9 +354,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): if user_id is not None: user_tpm_limit = user_api_key_dict.user_tpm_limit user_rpm_limit = user_api_key_dict.user_rpm_limit - if user_tpm_limit is None: + if user_tpm_limit is None or user_tpm_limit == 0: # Treat 0 as no limit user_tpm_limit = sys.maxsize - if user_rpm_limit is None: + if user_rpm_limit is None or user_rpm_limit == 0: # Treat 0 as no limit user_rpm_limit = sys.maxsize request_count_api_key = f"{user_id}::{precise_minute}::request_count" @@ -378,9 +382,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): team_tpm_limit = user_api_key_dict.team_tpm_limit team_rpm_limit = user_api_key_dict.team_rpm_limit - if team_tpm_limit is None: + if team_tpm_limit is None or team_tpm_limit == 0: # Treat 0 as no limit team_tpm_limit = sys.maxsize - if team_rpm_limit is None: + if team_rpm_limit is None or team_rpm_limit == 0: # Treat 0 as no limit team_rpm_limit = sys.maxsize request_count_api_key = f"{team_id}::{precise_minute}::request_count" @@ -409,9 +413,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): user_api_key_dict, "end_user_rpm_limit", sys.maxsize ) - if end_user_tpm_limit is None: + if ( + end_user_tpm_limit is None or end_user_tpm_limit == 0 + ): # Treat 0 as no limit end_user_tpm_limit = sys.maxsize - if end_user_rpm_limit is None: + 
if ( + end_user_rpm_limit is None or end_user_rpm_limit == 0 + ): # Treat 0 as no limit end_user_rpm_limit = sys.maxsize # now do the same tpm/rpm checks diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 097f798de2..da18e5bcab 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -491,12 +491,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915 ) ) + # Make a copy of the data *before* adding the proxy_server_request key + original_data_copy = copy.deepcopy(data) # Use deepcopy for nested structures + # Include original request and headers in the data data["proxy_server_request"] = { "url": str(request.url), "method": request.method, "headers": _headers, - "body": copy.copy(data), # use copy instead of deepcopy + "body": original_data_copy, # Use the deep copy without the circular reference } ## Dynamic api version (Azure OpenAI endpoints) ## diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 7695306c87..713b751216 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -3,6 +3,7 @@ import sys, os, asyncio, time, random, uuid import traceback +from datetime import datetime, timedelta from dotenv import load_dotenv load_dotenv() @@ -21,6 +22,7 @@ from litellm.proxy._types import ( LiteLLM_BudgetTable, LiteLLM_UserTable, LiteLLM_TeamTable, + LiteLLM_VerificationTokenView, ) from litellm.proxy.utils import PrismaClient from litellm.proxy.auth.auth_checks import ( @@ -29,6 +31,8 @@ from litellm.proxy.auth.auth_checks import ( ) from litellm.proxy.utils import ProxyLogging from litellm.proxy.utils import CallInfo +from pydantic import BaseModel +from unittest.mock import AsyncMock, MagicMock, patch @pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)]) @@ -255,7 +259,216 @@ async def test_can_key_call_model_wildcard_access(key_models, 
model, expect_to_w llm_router=router, ) - print(e) + +# Mock ProxyLogging for budget alert testing +class MockProxyLogging: + def __init__(self): + self.alert_triggered = False + self.alert_type = None + self.user_info = None + + async def budget_alerts(self, type, user_info): + self.alert_triggered = True + self.alert_type = type + self.user_info = user_info + + # Add dummy methods for other required ProxyLogging methods if needed + async def pre_call_hook(self, *args, **kwargs): + pass + + async def post_call_failure_hook(self, *args, **kwargs): + pass + + async def post_call_success_hook(self, *args, **kwargs): + pass + + async def async_post_call_streaming_hook(self, *args, **kwargs): + pass + + def async_post_call_streaming_iterator_hook(self, response, *args, **kwargs): + return response + + def _init_response_taking_too_long_task(self, *args, **kwargs): + pass + + async def update_request_status(self, *args, **kwargs): + pass + + async def failed_tracking_alert(self, *args, **kwargs): + pass + + async def alerting_handler(self, *args, **kwargs): + pass + + async def failure_handler(self, *args, **kwargs): + pass + + +@pytest.mark.parametrize( + "token_spend, max_budget, expect_error", + [ + (5.0, 10.0, False), # Under budget + (9.99, 10.0, False), # Just under budget + (0.0, 0.0, True), # At zero budget + (0.0, None, False), # No budget set + (10.0, 10.0, True), # At budget limit + (15.0, 10.0, True), # Over budget + (None, 10.0, False), # Spend not tracked yet + ], +) +@pytest.mark.asyncio +async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error): + """ + Test the _PROXY_MaxBudgetLimiter pre-call hook directly. + This test verifies the fix applied in the hook itself. 
+ """ + from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter + from litellm.caching.caching import DualCache + from fastapi import HTTPException + + limiter = _PROXY_MaxBudgetLimiter() + mock_cache = ( + DualCache() + ) # The hook expects a cache object, even if not used in the updated logic + + # Ensure spend is a float, defaulting to 0.0 if None + actual_spend = token_spend if token_spend is not None else 0.0 + + user_api_key_dict = UserAPIKeyAuth( + token="test-token-hook", + spend=actual_spend, + max_budget=max_budget, + user_id="test-user-hook", + ) + + mock_data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]} + + try: + await limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=mock_cache, + data=mock_data, + call_type="completion", + ) + if expect_error: + pytest.fail( + f"Expected HTTPException for spend={token_spend}, max_budget={max_budget}" + ) + except HTTPException as e: + if not expect_error: + pytest.fail( + f"Unexpected HTTPException for spend={token_spend}, max_budget={max_budget}: {e.detail}" + ) + assert e.status_code == 429 + assert "Max budget limit reached" in e.detail + except Exception as e: + pytest.fail(f"Unexpected exception type {type(e).__name__} raised: {e}") + + +@pytest.mark.asyncio +async def test_get_key_object_loads_budget_table_limits(): + """ + Test if get_key_object correctly loads max_budget, tpm_limit, and rpm_limit + from the joined LiteLLM_BudgetTable when a budget_id is present on the key. 
+ """ + from litellm.proxy.auth.auth_checks import get_key_object + from unittest.mock import AsyncMock, MagicMock, patch + from litellm.proxy.utils import PrismaClient, ProxyLogging + from litellm.caching.caching import DualCache + + # Mock Prisma response simulating the joined view + mock_db_response_dict = { + "token": "hashed_test_token_budget_table", + "key_name": "sk-...test", + "key_alias": "test-budget-key", + "spend": 5.0, + "max_budget": None, # Budget on token table itself is None + "expires": None, + "models": [], + "aliases": {}, + "config": {}, + "user_id": "test-user-budget", + "team_id": None, + "max_parallel_requests": None, + "metadata": {}, + "tpm_limit": None, # Limit on token table itself is None + "rpm_limit": None, # Limit on token table itself is None + "budget_duration": None, + "budget_reset_at": None, + "allowed_cache_controls": [], + "permissions": {}, + "model_spend": {}, + "model_max_budget": {}, + "soft_budget_cooldown": False, + "blocked": False, + "org_id": None, + "created_at": datetime.now(), + "updated_at": datetime.now(), + "created_by": None, + "updated_by": None, + "team_spend": None, + "team_alias": None, + "team_tpm_limit": None, + "team_rpm_limit": None, + "team_max_budget": None, + "team_models": [], + "team_blocked": False, + "team_model_aliases": None, + "team_member_spend": None, + "team_member": None, + "team_metadata": None, + "budget_id": "budget_123", # Link to budget table + # Values coming from the joined LiteLLM_BudgetTable + "litellm_budget_table_max_budget": 20.0, + "litellm_budget_table_soft_budget": 15.0, + "litellm_budget_table_tpm_limit": 1000, + "litellm_budget_table_rpm_limit": 100, + "litellm_budget_table_model_max_budget": {"gpt-4": 5.0}, + } + + # Mock PrismaClient and its methods + mock_prisma_client = MagicMock(spec=PrismaClient) + + # Create a mock object that mimics the structure returned by prisma.get_data for the raw query + mock_db_result = MagicMock() + for key, value in 
mock_db_response_dict.items(): + setattr(mock_db_result, key, value) + + # Add a model_dump method to the mock object + mock_db_result.model_dump = MagicMock(return_value=mock_db_response_dict) + + # Mock the get_data method to return our simulated DB response object + mock_prisma_client.get_data = AsyncMock(return_value=mock_db_result) + + mock_cache = DualCache() + mock_proxy_logging = MockProxyLogging() # Use the mock defined earlier + + # Call get_key_object + user_auth_obj = await get_key_object( + hashed_token="hashed_test_token_budget_table", + prisma_client=mock_prisma_client, + user_api_key_cache=mock_cache, + proxy_logging_obj=mock_proxy_logging, + ) + + # Assertions + assert ( + user_auth_obj.max_budget == 20.0 + ), "max_budget should be loaded from budget table" + assert ( + user_auth_obj.soft_budget == 15.0 + ), "soft_budget should be loaded from budget table" + assert ( + user_auth_obj.tpm_limit == 1000 + ), "tpm_limit should be loaded from budget table" + assert ( + user_auth_obj.rpm_limit == 100 + ), "rpm_limit should be loaded from budget table" + assert user_auth_obj.model_max_budget == { + "gpt-4": 5.0 + }, "model_max_budget should be loaded from budget table" + # Ensure original values from token table are not used if budget table values exist + assert user_auth_obj.spend == 5.0 # Spend comes from the token table itself @pytest.mark.asyncio @@ -319,13 +532,11 @@ async def test_virtual_key_max_budget_check( user_obj = LiteLLM_UserTable( user_id="test-user", - user_email="test@email.com", + user_email="test@example.com", max_budget=None, ) - proxy_logging_obj = ProxyLogging( - user_api_key_cache=None, - ) + proxy_logging_obj = MockProxyLogging() # Track if budget alert was called alert_called = False @@ -356,7 +567,6 @@ async def test_virtual_key_max_budget_check( await asyncio.sleep(1) - # Verify budget alert was triggered assert alert_called, "Budget alert should be triggered" @@ -477,6 +687,89 @@ async def 
test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert): ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}" +@pytest.mark.parametrize( + "token_spend, max_budget_from_table, expect_budget_error", + [ + (5.0, 10.0, False), # Under budget + (10.0, 10.0, True), # At budget limit + (15.0, 10.0, True), # Over budget + (5.0, None, False), # No budget set in table + ], +) +@pytest.mark.asyncio +async def test_virtual_key_max_budget_check_from_budget_table( + token_spend, max_budget_from_table, expect_budget_error +): + """ + Test if virtual key budget checks work when max_budget is derived + from the joined LiteLLM_BudgetTable data. + """ + from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check + from litellm.proxy.utils import ProxyLogging + + # Setup test data - Simulate data structure after get_key_object fix + valid_token = UserAPIKeyAuth( + token="test-token-from-table", + spend=token_spend, + max_budget=max_budget_from_table, # This now reflects the budget from the table + user_id="test-user-table", + key_alias="test-key-table", + # Simulate that litellm_budget_table was present during the join + litellm_budget_table={ + "max_budget": max_budget_from_table, + "soft_budget": None, # Assuming soft_budget is not the focus here + # Add other necessary fields if the model requires them + } + if max_budget_from_table is not None + else None, + ) + + user_obj = LiteLLM_UserTable( + user_id="test-user-table", + user_email="test-table@example.com", + max_budget=None, # Ensure user-level budget doesn't interfere + ) + + proxy_logging_obj = MockProxyLogging() # Use the mock class defined above + + # Track if budget alert was called + alert_called = False + + async def mock_budget_alert(*args, **kwargs): + nonlocal alert_called + alert_called = True + + proxy_logging_obj.budget_alerts = mock_budget_alert + + try: + await _virtual_key_max_budget_check( + valid_token=valid_token, + 
proxy_logging_obj=proxy_logging_obj, + user_obj=user_obj, + ) + if expect_budget_error: + pytest.fail( + f"Expected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}" + ) + except litellm.BudgetExceededError as e: + if not expect_budget_error: + pytest.fail( + f"Unexpected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}" + ) + assert e.current_cost == token_spend + assert e.max_budget == max_budget_from_table + + await asyncio.sleep(0.1) # Allow time for alert task + + # Verify budget alert was triggered only if there was a budget + if max_budget_from_table is not None: + assert alert_called, "Budget alert should be triggered when max_budget is set" + else: + assert ( + not alert_called + ), "Budget alert should not be triggered when max_budget is None" + + @pytest.mark.asyncio async def test_can_user_call_model(): from litellm.proxy.auth.auth_checks import can_user_call_model