mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(budgets): fix rejecting requests when user over limits
This commit is contained in:
parent
ad4aca02b2
commit
bffeeb4f2e
4 changed files with 157 additions and 22 deletions
|
@ -1007,6 +1007,14 @@ async def get_key_object(
|
|||
token_data["max_budget"] = token_data.pop("litellm_budget_table_max_budget")
|
||||
if token_data.get("litellm_budget_table_soft_budget") is not None:
|
||||
token_data["soft_budget"] = token_data.pop("litellm_budget_table_soft_budget")
|
||||
# Only override if budget table value is explicitly set (not None) and > 0
|
||||
budget_tpm_limit = token_data.pop("litellm_budget_table_tpm_limit", None)
|
||||
if budget_tpm_limit is not None and budget_tpm_limit > 0:
|
||||
token_data["tpm_limit"] = budget_tpm_limit
|
||||
budget_rpm_limit = token_data.pop("litellm_budget_table_rpm_limit", None)
|
||||
if budget_rpm_limit is not None and budget_rpm_limit > 0:
|
||||
token_data["rpm_limit"] = budget_rpm_limit
|
||||
|
||||
|
||||
_response = UserAPIKeyAuth(**token_data)
|
||||
|
||||
|
|
|
@ -195,10 +195,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
"global_max_parallel_requests", None
|
||||
)
|
||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||
if tpm_limit is None:
|
||||
if tpm_limit is None or tpm_limit == 0: # Treat 0 as no limit
|
||||
tpm_limit = sys.maxsize
|
||||
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
||||
if rpm_limit is None:
|
||||
if rpm_limit is None or rpm_limit == 0: # Treat 0 as no limit
|
||||
rpm_limit = sys.maxsize
|
||||
|
||||
values_to_update_in_cache: List[
|
||||
|
@ -310,9 +310,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
if _model is not None:
|
||||
if _tpm_limit_for_key_model:
|
||||
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
|
||||
if tpm_limit_for_model == 0: # Treat 0 as no limit
|
||||
tpm_limit_for_model = sys.maxsize
|
||||
|
||||
if _rpm_limit_for_key_model:
|
||||
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
|
||||
if rpm_limit_for_model == 0: # Treat 0 as no limit
|
||||
rpm_limit_for_model = sys.maxsize
|
||||
|
||||
new_val = await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
|
@ -350,9 +354,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
if user_id is not None:
|
||||
user_tpm_limit = user_api_key_dict.user_tpm_limit
|
||||
user_rpm_limit = user_api_key_dict.user_rpm_limit
|
||||
if user_tpm_limit is None:
|
||||
if user_tpm_limit is None or user_tpm_limit == 0: # Treat 0 as no limit
|
||||
user_tpm_limit = sys.maxsize
|
||||
if user_rpm_limit is None:
|
||||
if user_rpm_limit is None or user_rpm_limit == 0: # Treat 0 as no limit
|
||||
user_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||
|
@ -378,9 +382,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
team_tpm_limit = user_api_key_dict.team_tpm_limit
|
||||
team_rpm_limit = user_api_key_dict.team_rpm_limit
|
||||
|
||||
if team_tpm_limit is None:
|
||||
if team_tpm_limit is None or team_tpm_limit == 0: # Treat 0 as no limit
|
||||
team_tpm_limit = sys.maxsize
|
||||
if team_rpm_limit is None:
|
||||
if team_rpm_limit is None or team_rpm_limit == 0: # Treat 0 as no limit
|
||||
team_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
|
||||
|
@ -409,9 +413,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
|
||||
)
|
||||
|
||||
if end_user_tpm_limit is None:
|
||||
if (
|
||||
end_user_tpm_limit is None or end_user_tpm_limit == 0
|
||||
): # Treat 0 as no limit
|
||||
end_user_tpm_limit = sys.maxsize
|
||||
if end_user_rpm_limit is None:
|
||||
if (
|
||||
end_user_rpm_limit is None or end_user_rpm_limit == 0
|
||||
): # Treat 0 as no limit
|
||||
end_user_rpm_limit = sys.maxsize
|
||||
|
||||
# now do the same tpm/rpm checks
|
||||
|
|
|
@ -491,12 +491,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
)
|
||||
)
|
||||
|
||||
# Make a copy of the data *before* adding the proxy_server_request key
|
||||
original_data_copy = copy.deepcopy(data) # Use deepcopy for nested structures
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = {
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"headers": _headers,
|
||||
"body": copy.copy(data), # use copy instead of deepcopy
|
||||
"body": original_data_copy, # Use the deep copy without the circular reference
|
||||
}
|
||||
|
||||
## Dynamic api version (Azure OpenAI endpoints) ##
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import sys, os, asyncio, time, random, uuid
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
@ -21,6 +22,7 @@ from litellm.proxy._types import (
|
|||
LiteLLM_BudgetTable,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_VerificationTokenView,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
|
@ -29,6 +31,8 @@ from litellm.proxy.auth.auth_checks import (
|
|||
)
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.proxy.utils import CallInfo
|
||||
from pydantic import BaseModel
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
|
||||
|
@ -323,7 +327,9 @@ async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error):
|
|||
from fastapi import HTTPException
|
||||
|
||||
limiter = _PROXY_MaxBudgetLimiter()
|
||||
mock_cache = DualCache() # The hook expects a cache object, even if not used in the updated logic
|
||||
mock_cache = (
|
||||
DualCache()
|
||||
) # The hook expects a cache object, even if not used in the updated logic
|
||||
|
||||
# Ensure spend is a float, defaulting to 0.0 if None
|
||||
actual_spend = token_spend if token_spend is not None else 0.0
|
||||
|
@ -359,6 +365,112 @@ async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error):
|
|||
pytest.fail(f"Unexpected exception type {type(e).__name__} raised: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_object_loads_budget_table_limits():
|
||||
"""
|
||||
Test if get_key_object correctly loads max_budget, tpm_limit, and rpm_limit
|
||||
from the joined LiteLLM_BudgetTable when a budget_id is present on the key.
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import get_key_object
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
# Mock Prisma response simulating the joined view
|
||||
mock_db_response_dict = {
|
||||
"token": "hashed_test_token_budget_table",
|
||||
"key_name": "sk-...test",
|
||||
"key_alias": "test-budget-key",
|
||||
"spend": 5.0,
|
||||
"max_budget": None, # Budget on token table itself is None
|
||||
"expires": None,
|
||||
"models": [],
|
||||
"aliases": {},
|
||||
"config": {},
|
||||
"user_id": "test-user-budget",
|
||||
"team_id": None,
|
||||
"max_parallel_requests": None,
|
||||
"metadata": {},
|
||||
"tpm_limit": None, # Limit on token table itself is None
|
||||
"rpm_limit": None, # Limit on token table itself is None
|
||||
"budget_duration": None,
|
||||
"budget_reset_at": None,
|
||||
"allowed_cache_controls": [],
|
||||
"permissions": {},
|
||||
"model_spend": {},
|
||||
"model_max_budget": {},
|
||||
"soft_budget_cooldown": False,
|
||||
"blocked": False,
|
||||
"org_id": None,
|
||||
"created_at": datetime.now(),
|
||||
"updated_at": datetime.now(),
|
||||
"created_by": None,
|
||||
"updated_by": None,
|
||||
"team_spend": None,
|
||||
"team_alias": None,
|
||||
"team_tpm_limit": None,
|
||||
"team_rpm_limit": None,
|
||||
"team_max_budget": None,
|
||||
"team_models": [],
|
||||
"team_blocked": False,
|
||||
"team_model_aliases": None,
|
||||
"team_member_spend": None,
|
||||
"team_member": None,
|
||||
"team_metadata": None,
|
||||
"budget_id": "budget_123", # Link to budget table
|
||||
# Values coming from the joined LiteLLM_BudgetTable
|
||||
"litellm_budget_table_max_budget": 20.0,
|
||||
"litellm_budget_table_soft_budget": 15.0,
|
||||
"litellm_budget_table_tpm_limit": 1000,
|
||||
"litellm_budget_table_rpm_limit": 100,
|
||||
"litellm_budget_table_model_max_budget": {"gpt-4": 5.0},
|
||||
}
|
||||
|
||||
# Mock PrismaClient and its methods
|
||||
mock_prisma_client = MagicMock(spec=PrismaClient)
|
||||
|
||||
# Create a mock object that mimics the structure returned by prisma.get_data for the raw query
|
||||
mock_db_result = MagicMock()
|
||||
for key, value in mock_db_response_dict.items():
|
||||
setattr(mock_db_result, key, value)
|
||||
|
||||
# Add a model_dump method to the mock object
|
||||
mock_db_result.model_dump = MagicMock(return_value=mock_db_response_dict)
|
||||
|
||||
# Mock the get_data method to return our simulated DB response object
|
||||
mock_prisma_client.get_data = AsyncMock(return_value=mock_db_result)
|
||||
|
||||
mock_cache = DualCache()
|
||||
mock_proxy_logging = MockProxyLogging() # Use the mock defined earlier
|
||||
|
||||
# Call get_key_object
|
||||
user_auth_obj = await get_key_object(
|
||||
hashed_token="hashed_test_token_budget_table",
|
||||
prisma_client=mock_prisma_client,
|
||||
user_api_key_cache=mock_cache,
|
||||
proxy_logging_obj=mock_proxy_logging,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert (
|
||||
user_auth_obj.max_budget == 20.0
|
||||
), "max_budget should be loaded from budget table"
|
||||
assert (
|
||||
user_auth_obj.soft_budget == 15.0
|
||||
), "soft_budget should be loaded from budget table"
|
||||
assert (
|
||||
user_auth_obj.tpm_limit == 1000
|
||||
), "tpm_limit should be loaded from budget table"
|
||||
assert (
|
||||
user_auth_obj.rpm_limit == 100
|
||||
), "rpm_limit should be loaded from budget table"
|
||||
assert user_auth_obj.model_max_budget == {
|
||||
"gpt-4": 5.0
|
||||
}, "model_max_budget should be loaded from budget table"
|
||||
# Ensure original values from token table are not used if budget table values exist
|
||||
assert user_auth_obj.spend == 5.0 # Spend comes from the token table itself
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_valid_fallback_model():
|
||||
from litellm.proxy.auth.auth_checks import is_valid_fallback_model
|
||||
|
@ -607,7 +719,9 @@ async def test_virtual_key_max_budget_check_from_budget_table(
|
|||
"max_budget": max_budget_from_table,
|
||||
"soft_budget": None, # Assuming soft_budget is not the focus here
|
||||
# Add other necessary fields if the model requires them
|
||||
} if max_budget_from_table is not None else None,
|
||||
}
|
||||
if max_budget_from_table is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
user_obj = LiteLLM_UserTable(
|
||||
|
@ -651,7 +765,9 @@ async def test_virtual_key_max_budget_check_from_budget_table(
|
|||
if max_budget_from_table is not None:
|
||||
assert alert_called, "Budget alert should be triggered when max_budget is set"
|
||||
else:
|
||||
assert not alert_called, "Budget alert should not be triggered when max_budget is None"
|
||||
assert (
|
||||
not alert_called
|
||||
), "Budget alert should not be triggered when max_budget is None"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue