Sam 2025-04-24 00:56:37 -07:00 committed by GitHub
commit 7082f4f95e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 355 additions and 26 deletions

View file

@@ -1007,7 +1007,24 @@ async def get_key_object(
                 code=status.HTTP_401_UNAUTHORIZED,
             )
 
-        _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
+        token_data = _valid_token.model_dump(exclude_none=True)
+
+        # Manually map the budget from the joined budget table if it exists
+        # This ensures the max_budget from LiteLLM_BudgetTable takes precedence if set
+        if token_data.get("litellm_budget_table_max_budget") is not None:
+            token_data["max_budget"] = token_data.pop("litellm_budget_table_max_budget")
+        if token_data.get("litellm_budget_table_soft_budget") is not None:
+            token_data["soft_budget"] = token_data.pop("litellm_budget_table_soft_budget")
+
+        # Only override if budget table value is explicitly set (not None) and > 0
+        budget_tpm_limit = token_data.pop("litellm_budget_table_tpm_limit", None)
+        if budget_tpm_limit is not None and budget_tpm_limit > 0:
+            token_data["tpm_limit"] = budget_tpm_limit
+        budget_rpm_limit = token_data.pop("litellm_budget_table_rpm_limit", None)
+        if budget_rpm_limit is not None and budget_rpm_limit > 0:
+            token_data["rpm_limit"] = budget_rpm_limit
+
+        _response = UserAPIKeyAuth(**token_data)
 
         # save the key object to cache
         await _cache_key_object(

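Taken on its own, the mapping added above has two distinct precedence rules: budget values from the joined table override the token row whenever they are set at all, while rate limits override only when set and positive. A minimal standalone sketch of that logic (the helper name apply_budget_table_overrides is hypothetical; the field names come from the diff, and token_data is assumed to be the dict produced by _valid_token.model_dump(exclude_none=True)):

def apply_budget_table_overrides(token_data: dict) -> dict:
    """Give joined LiteLLM_BudgetTable values precedence over the token row."""
    # Budgets: override whenever the budget-table value is set at all.
    for src, dst in [
        ("litellm_budget_table_max_budget", "max_budget"),
        ("litellm_budget_table_soft_budget", "soft_budget"),
    ]:
        if token_data.get(src) is not None:
            token_data[dst] = token_data.pop(src)
    # Rate limits: override only when explicitly set (not None) and > 0.
    for src, dst in [
        ("litellm_budget_table_tpm_limit", "tpm_limit"),
        ("litellm_budget_table_rpm_limit", "rpm_limit"),
    ]:
        limit = token_data.pop(src, None)
        if limit is not None and limit > 0:
            token_data[dst] = limit
    return token_data


# Example: budget-table values win over the (absent) token-row values.
print(apply_budget_table_overrides(
    {"spend": 5.0,
     "litellm_budget_table_max_budget": 20.0,
     "litellm_budget_table_tpm_limit": 1000}
))
# -> {'spend': 5.0, 'max_budget': 20.0, 'tpm_limit': 1000}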
View file

@@ -21,29 +21,37 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
     ):
         try:
             verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
-            cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
-            user_row = await cache.async_get_cache(
-                cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
-            )
-            if user_row is None:  # value not yet cached
-                return
-            max_budget = user_row["max_budget"]
-            curr_spend = user_row["spend"]
+            # Use the budget information directly from the validated user_api_key_dict
+            max_budget = user_api_key_dict.max_budget
+            curr_spend = user_api_key_dict.spend
 
             if max_budget is None:
+                # No budget limit set for this key/user/team
                 return
 
             if curr_spend is None:
-                return
+                # If spend tracking hasn't started, assume 0
+                curr_spend = 0.0
 
             # CHECK IF REQUEST ALLOWED
             if curr_spend >= max_budget:
+                verbose_proxy_logger.info(
+                    f"Budget Limit Reached for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}"
+                )
                 raise HTTPException(status_code=429, detail="Max budget limit reached.")
+            else:
+                verbose_proxy_logger.debug(
+                    f"Budget Check Passed for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}"
+                )
         except HTTPException as e:
+            # Re-raise HTTPException to ensure FastAPI handles it correctly
             raise e
         except Exception as e:
             verbose_logger.exception(
-                "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
+                "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            pass

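Stripped of logging and proxy plumbing, the reworked hook reduces to a small pure check. Here is a sketch under that reading (check_budget is a hypothetical helper, not part of the diff; HTTPException is FastAPI's, as in the hook itself):

from fastapi import HTTPException


def check_budget(max_budget, curr_spend):
    """Raise 429 once tracked spend reaches the configured budget."""
    if max_budget is None:
        return  # no budget configured for this key/user/team
    if curr_spend is None:
        curr_spend = 0.0  # spend tracking has not started yet
    if curr_spend >= max_budget:
        raise HTTPException(status_code=429, detail="Max budget limit reached.")

Note that the comparison is ">=", so a zero budget rejects immediately even with zero spend, which is exactly the (0.0, 0.0, True) case in the parametrized test added later in this commit.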
View file

@@ -195,10 +195,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             "global_max_parallel_requests", None
         )
         tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
-        if tpm_limit is None:
+        if tpm_limit is None or tpm_limit == 0:  # Treat 0 as no limit
             tpm_limit = sys.maxsize
         rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
-        if rpm_limit is None:
+        if rpm_limit is None or rpm_limit == 0:  # Treat 0 as no limit
             rpm_limit = sys.maxsize
 
         values_to_update_in_cache: List[
@@ -310,9 +310,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             if _model is not None:
                 if _tpm_limit_for_key_model:
                     tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
+                    if tpm_limit_for_model == 0:  # Treat 0 as no limit
+                        tpm_limit_for_model = sys.maxsize
                 if _rpm_limit_for_key_model:
                     rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
+                    if rpm_limit_for_model == 0:  # Treat 0 as no limit
+                        rpm_limit_for_model = sys.maxsize
 
             new_val = await self.check_key_in_limits(
                 user_api_key_dict=user_api_key_dict,
@@ -350,9 +354,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if user_id is not None:
             user_tpm_limit = user_api_key_dict.user_tpm_limit
             user_rpm_limit = user_api_key_dict.user_rpm_limit
-            if user_tpm_limit is None:
+            if user_tpm_limit is None or user_tpm_limit == 0:  # Treat 0 as no limit
                 user_tpm_limit = sys.maxsize
-            if user_rpm_limit is None:
+            if user_rpm_limit is None or user_rpm_limit == 0:  # Treat 0 as no limit
                 user_rpm_limit = sys.maxsize
 
             request_count_api_key = f"{user_id}::{precise_minute}::request_count"
@@ -378,9 +382,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             team_tpm_limit = user_api_key_dict.team_tpm_limit
             team_rpm_limit = user_api_key_dict.team_rpm_limit
-            if team_tpm_limit is None:
+            if team_tpm_limit is None or team_tpm_limit == 0:  # Treat 0 as no limit
                 team_tpm_limit = sys.maxsize
-            if team_rpm_limit is None:
+            if team_rpm_limit is None or team_rpm_limit == 0:  # Treat 0 as no limit
                 team_rpm_limit = sys.maxsize
 
             request_count_api_key = f"{team_id}::{precise_minute}::request_count"
@@ -409,9 +413,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             end_user_rpm_limit = getattr(
                 user_api_key_dict, "end_user_rpm_limit", sys.maxsize
             )
-            if end_user_tpm_limit is None:
+            if (
+                end_user_tpm_limit is None or end_user_tpm_limit == 0
+            ):  # Treat 0 as no limit
                 end_user_tpm_limit = sys.maxsize
-            if end_user_rpm_limit is None:
+            if (
+                end_user_rpm_limit is None or end_user_rpm_limit == 0
+            ):  # Treat 0 as no limit
                 end_user_rpm_limit = sys.maxsize
 
             # now do the same tpm/rpm checks

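The same normalization now recurs at the key, per-model, user, team, and end-user levels: both None and 0 mean "no limit". Factored out, the pattern is just the following (effective_limit is a hypothetical helper for illustration, not part of the diff):

import sys


def effective_limit(limit):
    """Map an unset (None) or zero TPM/RPM limit to 'unlimited'."""
    if limit is None or limit == 0:  # Treat 0 as no limit
        return sys.maxsize
    return limit


assert effective_limit(None) == sys.maxsize
assert effective_limit(0) == sys.maxsize
assert effective_limit(100) == 100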
View file

@@ -491,12 +491,15 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
             )
         )
 
+    # Make a copy of the data *before* adding the proxy_server_request key
+    original_data_copy = copy.deepcopy(data)  # Use deepcopy for nested structures
+
     # Include original request and headers in the data
     data["proxy_server_request"] = {
         "url": str(request.url),
         "method": request.method,
         "headers": _headers,
-        "body": copy.copy(data),  # use copy instead of deepcopy
+        "body": original_data_copy,  # Use the deep copy without the circular reference
     }
 
     ## Dynamic api version (Azure OpenAI endpoints) ##

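The behavioral difference is easiest to see in isolation. With copy.copy the logged "body" keeps sharing every nested object with the live request dict, so later mutations leak into it; the commit's comment additionally attributes a circular-reference hazard to nesting such a shared structure back into data. The deep copy is an independent snapshot taken before the self-describing key is added. A standalone sketch, not the proxy's actual code path:

import copy

data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]}

shallow = copy.copy(data)   # top-level copy; nested objects are shared
deep = copy.deepcopy(data)  # fully independent snapshot

# Nest the "body" back into the request dict, as the proxy does, then mutate.
data["proxy_server_request"] = {"body": shallow}
data["messages"].append({"role": "assistant", "content": "..."})

assert len(shallow["messages"]) == 2  # the shallow body saw the later mutation
assert len(deep["messages"]) == 1     # the deep snapshot did not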
View file

@@ -3,6 +3,7 @@
 import sys, os, asyncio, time, random, uuid
 import traceback
+from datetime import datetime, timedelta
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -21,6 +22,7 @@ from litellm.proxy._types import (
     LiteLLM_BudgetTable,
     LiteLLM_UserTable,
     LiteLLM_TeamTable,
+    LiteLLM_VerificationTokenView,
 )
 from litellm.proxy.utils import PrismaClient
 from litellm.proxy.auth.auth_checks import (
@@ -29,6 +31,8 @@ from litellm.proxy.auth.auth_checks import (
 )
 from litellm.proxy.utils import ProxyLogging
 from litellm.proxy.utils import CallInfo
+from pydantic import BaseModel
+from unittest.mock import AsyncMock, MagicMock, patch
 
 
 @pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
@@ -255,7 +259,216 @@ async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_w
                 llm_router=router,
             )
         print(e)
 
 
+# Mock ProxyLogging for budget alert testing
+class MockProxyLogging:
+    def __init__(self):
+        self.alert_triggered = False
+        self.alert_type = None
+        self.user_info = None
+
+    async def budget_alerts(self, type, user_info):
+        self.alert_triggered = True
+        self.alert_type = type
+        self.user_info = user_info
+
+    # Add dummy methods for other required ProxyLogging methods if needed
+    async def pre_call_hook(self, *args, **kwargs):
+        pass
+
+    async def post_call_failure_hook(self, *args, **kwargs):
+        pass
+
+    async def post_call_success_hook(self, *args, **kwargs):
+        pass
+
+    async def async_post_call_streaming_hook(self, *args, **kwargs):
+        pass
+
+    def async_post_call_streaming_iterator_hook(self, response, *args, **kwargs):
+        return response
+
+    def _init_response_taking_too_long_task(self, *args, **kwargs):
+        pass
+
+    async def update_request_status(self, *args, **kwargs):
+        pass
+
+    async def failed_tracking_alert(self, *args, **kwargs):
+        pass
+
+    async def alerting_handler(self, *args, **kwargs):
+        pass
+
+    async def failure_handler(self, *args, **kwargs):
+        pass
+
+
+@pytest.mark.parametrize(
+    "token_spend, max_budget, expect_error",
+    [
+        (5.0, 10.0, False),  # Under budget
+        (9.99, 10.0, False),  # Just under budget
+        (0.0, 0.0, True),  # At zero budget
+        (0.0, None, False),  # No budget set
+        (10.0, 10.0, True),  # At budget limit
+        (15.0, 10.0, True),  # Over budget
+        (None, 10.0, False),  # Spend not tracked yet
+    ],
+)
+@pytest.mark.asyncio
+async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error):
+    """
+    Test the _PROXY_MaxBudgetLimiter pre-call hook directly.
+    This test verifies the fix applied in the hook itself.
+    """
+    from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
+    from litellm.caching.caching import DualCache
+    from fastapi import HTTPException
+
+    limiter = _PROXY_MaxBudgetLimiter()
+    mock_cache = (
+        DualCache()
+    )  # The hook expects a cache object, even if not used in the updated logic
+
+    # Ensure spend is a float, defaulting to 0.0 if None
+    actual_spend = token_spend if token_spend is not None else 0.0
+    user_api_key_dict = UserAPIKeyAuth(
+        token="test-token-hook",
+        spend=actual_spend,
+        max_budget=max_budget,
+        user_id="test-user-hook",
+    )
+    mock_data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}
+
+    try:
+        await limiter.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=mock_cache,
+            data=mock_data,
+            call_type="completion",
+        )
+        if expect_error:
+            pytest.fail(
+                f"Expected HTTPException for spend={token_spend}, max_budget={max_budget}"
+            )
+    except HTTPException as e:
+        if not expect_error:
+            pytest.fail(
+                f"Unexpected HTTPException for spend={token_spend}, max_budget={max_budget}: {e.detail}"
+            )
+        assert e.status_code == 429
+        assert "Max budget limit reached" in e.detail
+    except Exception as e:
+        pytest.fail(f"Unexpected exception type {type(e).__name__} raised: {e}")
+
+
+@pytest.mark.asyncio
+async def test_get_key_object_loads_budget_table_limits():
+    """
+    Test if get_key_object correctly loads max_budget, tpm_limit, and rpm_limit
+    from the joined LiteLLM_BudgetTable when a budget_id is present on the key.
+    """
+    from litellm.proxy.auth.auth_checks import get_key_object
+    from unittest.mock import AsyncMock, MagicMock, patch
+    from litellm.proxy.utils import PrismaClient, ProxyLogging
+    from litellm.caching.caching import DualCache
+
+    # Mock Prisma response simulating the joined view
+    mock_db_response_dict = {
+        "token": "hashed_test_token_budget_table",
+        "key_name": "sk-...test",
+        "key_alias": "test-budget-key",
+        "spend": 5.0,
+        "max_budget": None,  # Budget on token table itself is None
+        "expires": None,
+        "models": [],
+        "aliases": {},
+        "config": {},
+        "user_id": "test-user-budget",
+        "team_id": None,
+        "max_parallel_requests": None,
+        "metadata": {},
+        "tpm_limit": None,  # Limit on token table itself is None
+        "rpm_limit": None,  # Limit on token table itself is None
+        "budget_duration": None,
+        "budget_reset_at": None,
+        "allowed_cache_controls": [],
+        "permissions": {},
+        "model_spend": {},
+        "model_max_budget": {},
+        "soft_budget_cooldown": False,
+        "blocked": False,
+        "org_id": None,
+        "created_at": datetime.now(),
+        "updated_at": datetime.now(),
+        "created_by": None,
+        "updated_by": None,
+        "team_spend": None,
+        "team_alias": None,
+        "team_tpm_limit": None,
+        "team_rpm_limit": None,
+        "team_max_budget": None,
+        "team_models": [],
+        "team_blocked": False,
+        "team_model_aliases": None,
+        "team_member_spend": None,
+        "team_member": None,
+        "team_metadata": None,
+        "budget_id": "budget_123",  # Link to budget table
+        # Values coming from the joined LiteLLM_BudgetTable
+        "litellm_budget_table_max_budget": 20.0,
+        "litellm_budget_table_soft_budget": 15.0,
+        "litellm_budget_table_tpm_limit": 1000,
+        "litellm_budget_table_rpm_limit": 100,
+        "litellm_budget_table_model_max_budget": {"gpt-4": 5.0},
+    }
+
+    # Mock PrismaClient and its methods
+    mock_prisma_client = MagicMock(spec=PrismaClient)
+
+    # Create a mock object that mimics the structure returned by prisma.get_data for the raw query
+    mock_db_result = MagicMock()
+    for key, value in mock_db_response_dict.items():
+        setattr(mock_db_result, key, value)
+
+    # Add a model_dump method to the mock object
+    mock_db_result.model_dump = MagicMock(return_value=mock_db_response_dict)
+
+    # Mock the get_data method to return our simulated DB response object
+    mock_prisma_client.get_data = AsyncMock(return_value=mock_db_result)
+
+    mock_cache = DualCache()
+    mock_proxy_logging = MockProxyLogging()  # Use the mock defined earlier
+
+    # Call get_key_object
+    user_auth_obj = await get_key_object(
+        hashed_token="hashed_test_token_budget_table",
+        prisma_client=mock_prisma_client,
+        user_api_key_cache=mock_cache,
+        proxy_logging_obj=mock_proxy_logging,
+    )
+
+    # Assertions
+    assert (
+        user_auth_obj.max_budget == 20.0
+    ), "max_budget should be loaded from budget table"
+    assert (
+        user_auth_obj.soft_budget == 15.0
+    ), "soft_budget should be loaded from budget table"
+    assert (
+        user_auth_obj.tpm_limit == 1000
+    ), "tpm_limit should be loaded from budget table"
+    assert (
+        user_auth_obj.rpm_limit == 100
+    ), "rpm_limit should be loaded from budget table"
+    assert user_auth_obj.model_max_budget == {
+        "gpt-4": 5.0
+    }, "model_max_budget should be loaded from budget table"
+
+    # Ensure original values from token table are not used if budget table values exist
+    assert user_auth_obj.spend == 5.0  # Spend comes from the token table itself
+
+
 @pytest.mark.asyncio
@@ -319,13 +532,11 @@ async def test_virtual_key_max_budget_check(
     user_obj = LiteLLM_UserTable(
         user_id="test-user",
-        user_email="test@email.com",
+        user_email="test@example.com",
         max_budget=None,
     )
 
-    proxy_logging_obj = ProxyLogging(
-        user_api_key_cache=None,
-    )
+    proxy_logging_obj = MockProxyLogging()
 
     # Track if budget alert was called
     alert_called = False
@@ -356,7 +567,6 @@ async def test_virtual_key_max_budget_check(
     await asyncio.sleep(1)
 
     # Verify budget alert was triggered
     assert alert_called, "Budget alert should be triggered"
-
@@ -477,6 +687,89 @@ async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
     ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"
 
 
+@pytest.mark.parametrize(
+    "token_spend, max_budget_from_table, expect_budget_error",
+    [
+        (5.0, 10.0, False),  # Under budget
+        (10.0, 10.0, True),  # At budget limit
+        (15.0, 10.0, True),  # Over budget
+        (5.0, None, False),  # No budget set in table
+    ],
+)
+@pytest.mark.asyncio
+async def test_virtual_key_max_budget_check_from_budget_table(
+    token_spend, max_budget_from_table, expect_budget_error
+):
+    """
+    Test if virtual key budget checks work when max_budget is derived
+    from the joined LiteLLM_BudgetTable data.
+    """
+    from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check
+    from litellm.proxy.utils import ProxyLogging
+
+    # Setup test data - Simulate data structure after get_key_object fix
+    valid_token = UserAPIKeyAuth(
+        token="test-token-from-table",
+        spend=token_spend,
+        max_budget=max_budget_from_table,  # This now reflects the budget from the table
+        user_id="test-user-table",
+        key_alias="test-key-table",
+        # Simulate that litellm_budget_table was present during the join
+        litellm_budget_table={
+            "max_budget": max_budget_from_table,
+            "soft_budget": None,  # Assuming soft_budget is not the focus here
+            # Add other necessary fields if the model requires them
+        }
+        if max_budget_from_table is not None
+        else None,
+    )
+    user_obj = LiteLLM_UserTable(
+        user_id="test-user-table",
+        user_email="test-table@example.com",
+        max_budget=None,  # Ensure user-level budget doesn't interfere
+    )
+    proxy_logging_obj = MockProxyLogging()  # Use the mock class defined above
+
+    # Track if budget alert was called
+    alert_called = False
+
+    async def mock_budget_alert(*args, **kwargs):
+        nonlocal alert_called
+        alert_called = True
+
+    proxy_logging_obj.budget_alerts = mock_budget_alert
+
+    try:
+        await _virtual_key_max_budget_check(
+            valid_token=valid_token,
+            proxy_logging_obj=proxy_logging_obj,
+            user_obj=user_obj,
+        )
+        if expect_budget_error:
+            pytest.fail(
+                f"Expected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}"
+            )
+    except litellm.BudgetExceededError as e:
+        if not expect_budget_error:
+            pytest.fail(
+                f"Unexpected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}"
+            )
+        assert e.current_cost == token_spend
+        assert e.max_budget == max_budget_from_table
+
+    await asyncio.sleep(0.1)  # Allow time for alert task
+
+    # Verify budget alert was triggered only if there was a budget
+    if max_budget_from_table is not None:
+        assert alert_called, "Budget alert should be triggered when max_budget is set"
+    else:
+        assert (
+            not alert_called
+        ), "Budget alert should not be triggered when max_budget is None"
+
+
 @pytest.mark.asyncio
 async def test_can_user_call_model():
     from litellm.proxy.auth.auth_checks import can_user_call_model