fix(budgets): fix rejecting requests when user is over limits

Sam 2025-04-01 09:24:25 +11:00
parent ad4aca02b2
commit bffeeb4f2e
4 changed files with 157 additions and 22 deletions

View file

@@ -1007,6 +1007,14 @@ async def get_key_object(
token_data["max_budget"] = token_data.pop("litellm_budget_table_max_budget")
if token_data.get("litellm_budget_table_soft_budget") is not None:
token_data["soft_budget"] = token_data.pop("litellm_budget_table_soft_budget")
# Only override if budget table value is explicitly set (not None) and > 0
budget_tpm_limit = token_data.pop("litellm_budget_table_tpm_limit", None)
if budget_tpm_limit is not None and budget_tpm_limit > 0:
token_data["tpm_limit"] = budget_tpm_limit
budget_rpm_limit = token_data.pop("litellm_budget_table_rpm_limit", None)
if budget_rpm_limit is not None and budget_rpm_limit > 0:
token_data["rpm_limit"] = budget_rpm_limit
_response = UserAPIKeyAuth(**token_data)
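Note: the override above is deliberately conservative. A budget-table limit replaces the key's own tpm_limit/rpm_limit only when it is explicitly set (not None) and greater than 0; a missing or zero value leaves the token-table limit in place. A standalone sketch of that precedence rule, using the joined-view column names from the hunk (apply_budget_table_limits is a hypothetical helper, not a function in the codebase):

from typing import Optional

def apply_budget_table_limits(token_data: dict) -> dict:
    # Budget-table limits win only when explicitly set and positive;
    # otherwise the key keeps its own limit (mirrors the hunk above).
    for field in ("tpm_limit", "rpm_limit"):
        budget_value: Optional[int] = token_data.pop(
            f"litellm_budget_table_{field}", None
        )
        if budget_value is not None and budget_value > 0:
            token_data[field] = budget_value
    return token_data

# Example: tpm_limit is overridden; rpm_limit (0 in the budget table) is not.
row = {
    "tpm_limit": None,
    "rpm_limit": 50,
    "litellm_budget_table_tpm_limit": 1000,
    "litellm_budget_table_rpm_limit": 0,
}
assert apply_budget_table_limits(row) == {"tpm_limit": 1000, "rpm_limit": 50}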

View file

@@ -195,10 +195,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
"global_max_parallel_requests", None
)
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
-if tpm_limit is None:
+if tpm_limit is None or tpm_limit == 0: # Treat 0 as no limit
tpm_limit = sys.maxsize
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
-if rpm_limit is None:
+if rpm_limit is None or rpm_limit == 0: # Treat 0 as no limit
rpm_limit = sys.maxsize
values_to_update_in_cache: List[
@@ -310,9 +310,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if _model is not None:
if _tpm_limit_for_key_model:
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
if tpm_limit_for_model == 0: # Treat 0 as no limit
tpm_limit_for_model = sys.maxsize
if _rpm_limit_for_key_model:
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
if rpm_limit_for_model == 0: # Treat 0 as no limit
rpm_limit_for_model = sys.maxsize
new_val = await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
@@ -350,9 +354,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if user_id is not None:
user_tpm_limit = user_api_key_dict.user_tpm_limit
user_rpm_limit = user_api_key_dict.user_rpm_limit
-if user_tpm_limit is None:
+if user_tpm_limit is None or user_tpm_limit == 0: # Treat 0 as no limit
user_tpm_limit = sys.maxsize
-if user_rpm_limit is None:
+if user_rpm_limit is None or user_rpm_limit == 0: # Treat 0 as no limit
user_rpm_limit = sys.maxsize
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
@@ -378,9 +382,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
team_tpm_limit = user_api_key_dict.team_tpm_limit
team_rpm_limit = user_api_key_dict.team_rpm_limit
-if team_tpm_limit is None:
+if team_tpm_limit is None or team_tpm_limit == 0: # Treat 0 as no limit
team_tpm_limit = sys.maxsize
-if team_rpm_limit is None:
+if team_rpm_limit is None or team_rpm_limit == 0: # Treat 0 as no limit
team_rpm_limit = sys.maxsize
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
@@ -409,9 +413,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
)
-if end_user_tpm_limit is None:
+if (
+end_user_tpm_limit is None or end_user_tpm_limit == 0
+): # Treat 0 as no limit
end_user_tpm_limit = sys.maxsize
-if end_user_rpm_limit is None:
+if (
+end_user_rpm_limit is None or end_user_rpm_limit == 0
+): # Treat 0 as no limit
end_user_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
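All five call sites above (key, per-key-model, user, team, end user) apply the same normalization: a limit of None or 0 now means "no limit" and is replaced with sys.maxsize, so the downstream count-versus-limit comparisons can no longer reject a caller whose limit was simply never configured. A hypothetical helper showing the pattern in isolation (the commit inlines the check at each site rather than factoring it out):

import sys
from typing import Optional

def effective_limit(limit: Optional[int]) -> int:
    # None and 0 both mean "not limited": map them to sys.maxsize so
    # `current_usage < limit` checks always pass for unconfigured limits.
    if limit is None or limit == 0:
        return sys.maxsize
    return limit

# Before this fix, a limit of 0 made every request look over-limit.
assert effective_limit(0) == sys.maxsize
assert effective_limit(None) == sys.maxsize
assert effective_limit(100) == 100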

View file

@@ -491,12 +491,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915
)
)
# Make a copy of the data *before* adding the proxy_server_request key
original_data_copy = copy.deepcopy(data) # Use deepcopy for nested structures
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": _headers,
"body": copy.copy(data), # use copy instead of deepcopy
"body": original_data_copy, # Use the deep copy without the circular reference
}
## Dynamic api version (Azure OpenAI endpoints) ##
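The shallow-to-deep copy change matters because copy.copy only duplicates the top-level dict: every nested object in "body" stays shared with the live request data, so a later write into a nested dict (for example, attaching proxy_server_request metadata) shows up inside the logged body and can make the structure self-referential, which breaks JSON serialization. A minimal illustration of the failure mode (the dict contents here are invented; only the copy semantics mirror the change):

import copy
import json

data = {"model": "gpt-4", "metadata": {}}

shallow_body = copy.copy(data)   # top level copied, nested "metadata" shared
deep_body = copy.deepcopy(data)  # fully independent nested structures

# A later write into the original's nested dict leaks into the shallow copy
# and creates a cycle: shallow_body -> metadata -> ... -> shallow_body.
data["metadata"]["proxy_server_request"] = {"body": shallow_body}

json.dumps(deep_body)  # fine
try:
    json.dumps(shallow_body)
except ValueError as err:
    print(err)  # Circular reference detected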

View file

@@ -3,6 +3,7 @@
import sys, os, asyncio, time, random, uuid
import traceback
from datetime import datetime, timedelta
from dotenv import load_dotenv
load_dotenv()
@@ -21,6 +22,7 @@ from litellm.proxy._types import (
LiteLLM_BudgetTable,
LiteLLM_UserTable,
LiteLLM_TeamTable,
LiteLLM_VerificationTokenView,
)
from litellm.proxy.utils import PrismaClient
from litellm.proxy.auth.auth_checks import (
@@ -29,6 +31,8 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import CallInfo
from pydantic import BaseModel
from unittest.mock import AsyncMock, MagicMock, patch
@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
@@ -323,7 +327,9 @@ async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error):
from fastapi import HTTPException
limiter = _PROXY_MaxBudgetLimiter()
-mock_cache = DualCache() # The hook expects a cache object, even if not used in the updated logic
+mock_cache = (
+DualCache()
+) # The hook expects a cache object, even if not used in the updated logic
# Ensure spend is a float, defaulting to 0.0 if None
actual_spend = token_spend if token_spend is not None else 0.0
@@ -359,6 +365,112 @@ async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error):
pytest.fail(f"Unexpected exception type {type(e).__name__} raised: {e}")
@pytest.mark.asyncio
async def test_get_key_object_loads_budget_table_limits():
"""
Test if get_key_object correctly loads max_budget, tpm_limit, and rpm_limit
from the joined LiteLLM_BudgetTable when a budget_id is present on the key.
"""
from litellm.proxy.auth.auth_checks import get_key_object
from unittest.mock import AsyncMock, MagicMock, patch
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.caching.caching import DualCache
# Mock Prisma response simulating the joined view
mock_db_response_dict = {
"token": "hashed_test_token_budget_table",
"key_name": "sk-...test",
"key_alias": "test-budget-key",
"spend": 5.0,
"max_budget": None, # Budget on token table itself is None
"expires": None,
"models": [],
"aliases": {},
"config": {},
"user_id": "test-user-budget",
"team_id": None,
"max_parallel_requests": None,
"metadata": {},
"tpm_limit": None, # Limit on token table itself is None
"rpm_limit": None, # Limit on token table itself is None
"budget_duration": None,
"budget_reset_at": None,
"allowed_cache_controls": [],
"permissions": {},
"model_spend": {},
"model_max_budget": {},
"soft_budget_cooldown": False,
"blocked": False,
"org_id": None,
"created_at": datetime.now(),
"updated_at": datetime.now(),
"created_by": None,
"updated_by": None,
"team_spend": None,
"team_alias": None,
"team_tpm_limit": None,
"team_rpm_limit": None,
"team_max_budget": None,
"team_models": [],
"team_blocked": False,
"team_model_aliases": None,
"team_member_spend": None,
"team_member": None,
"team_metadata": None,
"budget_id": "budget_123", # Link to budget table
# Values coming from the joined LiteLLM_BudgetTable
"litellm_budget_table_max_budget": 20.0,
"litellm_budget_table_soft_budget": 15.0,
"litellm_budget_table_tpm_limit": 1000,
"litellm_budget_table_rpm_limit": 100,
"litellm_budget_table_model_max_budget": {"gpt-4": 5.0},
}
# Mock PrismaClient and its methods
mock_prisma_client = MagicMock(spec=PrismaClient)
# Create a mock object that mimics the structure returned by prisma.get_data for the raw query
mock_db_result = MagicMock()
for key, value in mock_db_response_dict.items():
setattr(mock_db_result, key, value)
# Add a model_dump method to the mock object
mock_db_result.model_dump = MagicMock(return_value=mock_db_response_dict)
# Mock the get_data method to return our simulated DB response object
mock_prisma_client.get_data = AsyncMock(return_value=mock_db_result)
mock_cache = DualCache()
mock_proxy_logging = MockProxyLogging() # Use the mock defined earlier
# Call get_key_object
user_auth_obj = await get_key_object(
hashed_token="hashed_test_token_budget_table",
prisma_client=mock_prisma_client,
user_api_key_cache=mock_cache,
proxy_logging_obj=mock_proxy_logging,
)
# Assertions
assert (
user_auth_obj.max_budget == 20.0
), "max_budget should be loaded from budget table"
assert (
user_auth_obj.soft_budget == 15.0
), "soft_budget should be loaded from budget table"
assert (
user_auth_obj.tpm_limit == 1000
), "tpm_limit should be loaded from budget table"
assert (
user_auth_obj.rpm_limit == 100
), "rpm_limit should be loaded from budget table"
assert user_auth_obj.model_max_budget == {
"gpt-4": 5.0
}, "model_max_budget should be loaded from budget table"
# Ensure original values from token table are not used if budget table values exist
assert user_auth_obj.spend == 5.0 # Spend comes from the token table itself
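MockProxyLogging is defined earlier in the test module and does not appear in this diff. For reading the hunk in isolation, a plausible minimal stand-in only needs the async hooks get_key_object may await; this is a hypothetical sketch, not the mock the test actually uses:

from unittest.mock import AsyncMock

class MockProxyLogging:
    # Hypothetical stand-in: get_key_object only needs proxy_logging_obj's
    # async hooks to be awaitable, so AsyncMock covers them. The real mock
    # defined earlier in the file may stub more than this.
    def __init__(self):
        self.budget_alerts = AsyncMock()  # assumed hook name, from ProxyLogging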
@pytest.mark.asyncio
async def test_is_valid_fallback_model():
from litellm.proxy.auth.auth_checks import is_valid_fallback_model
@@ -607,7 +719,9 @@ async def test_virtual_key_max_budget_check_from_budget_table(
"max_budget": max_budget_from_table,
"soft_budget": None, # Assuming soft_budget is not the focus here
# Add other necessary fields if the model requires them
-} if max_budget_from_table is not None else None,
+}
+if max_budget_from_table is not None
+else None,
)
user_obj = LiteLLM_UserTable(
@@ -651,7 +765,9 @@ async def test_virtual_key_max_budget_check_from_budget_table(
if max_budget_from_table is not None:
assert alert_called, "Budget alert should be triggered when max_budget is set"
else:
-assert not alert_called, "Budget alert should not be triggered when max_budget is None"
+assert (
+not alert_called
+), "Budget alert should not be triggered when max_budget is None"
@pytest.mark.asyncio