mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge 3f8b827f79
into b82af5b826
This commit is contained in:
commit
7082f4f95e
5 changed files with 355 additions and 26 deletions
|
@ -1007,7 +1007,24 @@ async def get_key_object(
|
|||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
|
||||
token_data = _valid_token.model_dump(exclude_none=True)
|
||||
|
||||
# Manually map the budget from the joined budget table if it exists
|
||||
# This ensures the max_budget from LiteLLM_BudgetTable takes precedence if set
|
||||
if token_data.get("litellm_budget_table_max_budget") is not None:
|
||||
token_data["max_budget"] = token_data.pop("litellm_budget_table_max_budget")
|
||||
if token_data.get("litellm_budget_table_soft_budget") is not None:
|
||||
token_data["soft_budget"] = token_data.pop("litellm_budget_table_soft_budget")
|
||||
# Only override if budget table value is explicitly set (not None) and > 0
|
||||
budget_tpm_limit = token_data.pop("litellm_budget_table_tpm_limit", None)
|
||||
if budget_tpm_limit is not None and budget_tpm_limit > 0:
|
||||
token_data["tpm_limit"] = budget_tpm_limit
|
||||
budget_rpm_limit = token_data.pop("litellm_budget_table_rpm_limit", None)
|
||||
if budget_rpm_limit is not None and budget_rpm_limit > 0:
|
||||
token_data["rpm_limit"] = budget_rpm_limit
|
||||
|
||||
|
||||
_response = UserAPIKeyAuth(**token_data)
|
||||
|
||||
# save the key object to cache
|
||||
await _cache_key_object(
|
||||
|
|
|
@ -21,29 +21,37 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
|
|||
):
|
||||
try:
|
||||
verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
|
||||
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
|
||||
user_row = await cache.async_get_cache(
|
||||
cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
|
||||
)
|
||||
if user_row is None: # value not yet cached
|
||||
return
|
||||
max_budget = user_row["max_budget"]
|
||||
curr_spend = user_row["spend"]
|
||||
|
||||
# Use the budget information directly from the validated user_api_key_dict
|
||||
max_budget = user_api_key_dict.max_budget
|
||||
curr_spend = user_api_key_dict.spend
|
||||
|
||||
if max_budget is None:
|
||||
# No budget limit set for this key/user/team
|
||||
return
|
||||
|
||||
if curr_spend is None:
|
||||
return
|
||||
# If spend tracking hasn't started, assume 0
|
||||
curr_spend = 0.0
|
||||
|
||||
# CHECK IF REQUEST ALLOWED
|
||||
if curr_spend >= max_budget:
|
||||
verbose_proxy_logger.info(
|
||||
f"Budget Limit Reached for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}"
|
||||
)
|
||||
raise HTTPException(status_code=429, detail="Max budget limit reached.")
|
||||
else:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Budget Check Passed for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}"
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
# Re-raise HTTPException to ensure FastAPI handles it correctly
|
||||
raise e
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
|
|
@ -195,10 +195,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
"global_max_parallel_requests", None
|
||||
)
|
||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||
if tpm_limit is None:
|
||||
if tpm_limit is None or tpm_limit == 0: # Treat 0 as no limit
|
||||
tpm_limit = sys.maxsize
|
||||
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
||||
if rpm_limit is None:
|
||||
if rpm_limit is None or rpm_limit == 0: # Treat 0 as no limit
|
||||
rpm_limit = sys.maxsize
|
||||
|
||||
values_to_update_in_cache: List[
|
||||
|
@ -310,9 +310,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
if _model is not None:
|
||||
if _tpm_limit_for_key_model:
|
||||
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
|
||||
if tpm_limit_for_model == 0: # Treat 0 as no limit
|
||||
tpm_limit_for_model = sys.maxsize
|
||||
|
||||
if _rpm_limit_for_key_model:
|
||||
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
|
||||
if rpm_limit_for_model == 0: # Treat 0 as no limit
|
||||
rpm_limit_for_model = sys.maxsize
|
||||
|
||||
new_val = await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
|
@ -350,9 +354,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
if user_id is not None:
|
||||
user_tpm_limit = user_api_key_dict.user_tpm_limit
|
||||
user_rpm_limit = user_api_key_dict.user_rpm_limit
|
||||
if user_tpm_limit is None:
|
||||
if user_tpm_limit is None or user_tpm_limit == 0: # Treat 0 as no limit
|
||||
user_tpm_limit = sys.maxsize
|
||||
if user_rpm_limit is None:
|
||||
if user_rpm_limit is None or user_rpm_limit == 0: # Treat 0 as no limit
|
||||
user_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||
|
@ -378,9 +382,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
team_tpm_limit = user_api_key_dict.team_tpm_limit
|
||||
team_rpm_limit = user_api_key_dict.team_rpm_limit
|
||||
|
||||
if team_tpm_limit is None:
|
||||
if team_tpm_limit is None or team_tpm_limit == 0: # Treat 0 as no limit
|
||||
team_tpm_limit = sys.maxsize
|
||||
if team_rpm_limit is None:
|
||||
if team_rpm_limit is None or team_rpm_limit == 0: # Treat 0 as no limit
|
||||
team_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
|
||||
|
@ -409,9 +413,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
|
||||
)
|
||||
|
||||
if end_user_tpm_limit is None:
|
||||
if (
|
||||
end_user_tpm_limit is None or end_user_tpm_limit == 0
|
||||
): # Treat 0 as no limit
|
||||
end_user_tpm_limit = sys.maxsize
|
||||
if end_user_rpm_limit is None:
|
||||
if (
|
||||
end_user_rpm_limit is None or end_user_rpm_limit == 0
|
||||
): # Treat 0 as no limit
|
||||
end_user_rpm_limit = sys.maxsize
|
||||
|
||||
# now do the same tpm/rpm checks
|
||||
|
|
|
@ -491,12 +491,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
)
|
||||
)
|
||||
|
||||
# Make a copy of the data *before* adding the proxy_server_request key
|
||||
original_data_copy = copy.deepcopy(data) # Use deepcopy for nested structures
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = {
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"headers": _headers,
|
||||
"body": copy.copy(data), # use copy instead of deepcopy
|
||||
"body": original_data_copy, # Use the deep copy without the circular reference
|
||||
}
|
||||
|
||||
## Dynamic api version (Azure OpenAI endpoints) ##
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import sys, os, asyncio, time, random, uuid
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
@ -21,6 +22,7 @@ from litellm.proxy._types import (
|
|||
LiteLLM_BudgetTable,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_VerificationTokenView,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
|
@ -29,6 +31,8 @@ from litellm.proxy.auth.auth_checks import (
|
|||
)
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.proxy.utils import CallInfo
|
||||
from pydantic import BaseModel
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
|
||||
|
@ -255,7 +259,216 @@ async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_w
|
|||
llm_router=router,
|
||||
)
|
||||
|
||||
print(e)
|
||||
|
||||
# Mock ProxyLogging for budget alert testing
|
||||
class MockProxyLogging:
|
||||
def __init__(self):
|
||||
self.alert_triggered = False
|
||||
self.alert_type = None
|
||||
self.user_info = None
|
||||
|
||||
async def budget_alerts(self, type, user_info):
|
||||
self.alert_triggered = True
|
||||
self.alert_type = type
|
||||
self.user_info = user_info
|
||||
|
||||
# Add dummy methods for other required ProxyLogging methods if needed
|
||||
async def pre_call_hook(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def post_call_failure_hook(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def post_call_success_hook(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def async_post_call_streaming_hook(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def async_post_call_streaming_iterator_hook(self, response, *args, **kwargs):
|
||||
return response
|
||||
|
||||
def _init_response_taking_too_long_task(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def update_request_status(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def failed_tracking_alert(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def alerting_handler(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def failure_handler(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"token_spend, max_budget, expect_error",
|
||||
[
|
||||
(5.0, 10.0, False), # Under budget
|
||||
(9.99, 10.0, False), # Just under budget
|
||||
(0.0, 0.0, True), # At zero budget
|
||||
(0.0, None, False), # No budget set
|
||||
(10.0, 10.0, True), # At budget limit
|
||||
(15.0, 10.0, True), # Over budget
|
||||
(None, 10.0, False), # Spend not tracked yet
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error):
|
||||
"""
|
||||
Test the _PROXY_MaxBudgetLimiter pre-call hook directly.
|
||||
This test verifies the fix applied in the hook itself.
|
||||
"""
|
||||
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||
from litellm.caching.caching import DualCache
|
||||
from fastapi import HTTPException
|
||||
|
||||
limiter = _PROXY_MaxBudgetLimiter()
|
||||
mock_cache = (
|
||||
DualCache()
|
||||
) # The hook expects a cache object, even if not used in the updated logic
|
||||
|
||||
# Ensure spend is a float, defaulting to 0.0 if None
|
||||
actual_spend = token_spend if token_spend is not None else 0.0
|
||||
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
token="test-token-hook",
|
||||
spend=actual_spend,
|
||||
max_budget=max_budget,
|
||||
user_id="test-user-hook",
|
||||
)
|
||||
|
||||
mock_data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}
|
||||
|
||||
try:
|
||||
await limiter.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=mock_cache,
|
||||
data=mock_data,
|
||||
call_type="completion",
|
||||
)
|
||||
if expect_error:
|
||||
pytest.fail(
|
||||
f"Expected HTTPException for spend={token_spend}, max_budget={max_budget}"
|
||||
)
|
||||
except HTTPException as e:
|
||||
if not expect_error:
|
||||
pytest.fail(
|
||||
f"Unexpected HTTPException for spend={token_spend}, max_budget={max_budget}: {e.detail}"
|
||||
)
|
||||
assert e.status_code == 429
|
||||
assert "Max budget limit reached" in e.detail
|
||||
except Exception as e:
|
||||
pytest.fail(f"Unexpected exception type {type(e).__name__} raised: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_object_loads_budget_table_limits():
|
||||
"""
|
||||
Test if get_key_object correctly loads max_budget, tpm_limit, and rpm_limit
|
||||
from the joined LiteLLM_BudgetTable when a budget_id is present on the key.
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import get_key_object
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
# Mock Prisma response simulating the joined view
|
||||
mock_db_response_dict = {
|
||||
"token": "hashed_test_token_budget_table",
|
||||
"key_name": "sk-...test",
|
||||
"key_alias": "test-budget-key",
|
||||
"spend": 5.0,
|
||||
"max_budget": None, # Budget on token table itself is None
|
||||
"expires": None,
|
||||
"models": [],
|
||||
"aliases": {},
|
||||
"config": {},
|
||||
"user_id": "test-user-budget",
|
||||
"team_id": None,
|
||||
"max_parallel_requests": None,
|
||||
"metadata": {},
|
||||
"tpm_limit": None, # Limit on token table itself is None
|
||||
"rpm_limit": None, # Limit on token table itself is None
|
||||
"budget_duration": None,
|
||||
"budget_reset_at": None,
|
||||
"allowed_cache_controls": [],
|
||||
"permissions": {},
|
||||
"model_spend": {},
|
||||
"model_max_budget": {},
|
||||
"soft_budget_cooldown": False,
|
||||
"blocked": False,
|
||||
"org_id": None,
|
||||
"created_at": datetime.now(),
|
||||
"updated_at": datetime.now(),
|
||||
"created_by": None,
|
||||
"updated_by": None,
|
||||
"team_spend": None,
|
||||
"team_alias": None,
|
||||
"team_tpm_limit": None,
|
||||
"team_rpm_limit": None,
|
||||
"team_max_budget": None,
|
||||
"team_models": [],
|
||||
"team_blocked": False,
|
||||
"team_model_aliases": None,
|
||||
"team_member_spend": None,
|
||||
"team_member": None,
|
||||
"team_metadata": None,
|
||||
"budget_id": "budget_123", # Link to budget table
|
||||
# Values coming from the joined LiteLLM_BudgetTable
|
||||
"litellm_budget_table_max_budget": 20.0,
|
||||
"litellm_budget_table_soft_budget": 15.0,
|
||||
"litellm_budget_table_tpm_limit": 1000,
|
||||
"litellm_budget_table_rpm_limit": 100,
|
||||
"litellm_budget_table_model_max_budget": {"gpt-4": 5.0},
|
||||
}
|
||||
|
||||
# Mock PrismaClient and its methods
|
||||
mock_prisma_client = MagicMock(spec=PrismaClient)
|
||||
|
||||
# Create a mock object that mimics the structure returned by prisma.get_data for the raw query
|
||||
mock_db_result = MagicMock()
|
||||
for key, value in mock_db_response_dict.items():
|
||||
setattr(mock_db_result, key, value)
|
||||
|
||||
# Add a model_dump method to the mock object
|
||||
mock_db_result.model_dump = MagicMock(return_value=mock_db_response_dict)
|
||||
|
||||
# Mock the get_data method to return our simulated DB response object
|
||||
mock_prisma_client.get_data = AsyncMock(return_value=mock_db_result)
|
||||
|
||||
mock_cache = DualCache()
|
||||
mock_proxy_logging = MockProxyLogging() # Use the mock defined earlier
|
||||
|
||||
# Call get_key_object
|
||||
user_auth_obj = await get_key_object(
|
||||
hashed_token="hashed_test_token_budget_table",
|
||||
prisma_client=mock_prisma_client,
|
||||
user_api_key_cache=mock_cache,
|
||||
proxy_logging_obj=mock_proxy_logging,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert (
|
||||
user_auth_obj.max_budget == 20.0
|
||||
), "max_budget should be loaded from budget table"
|
||||
assert (
|
||||
user_auth_obj.soft_budget == 15.0
|
||||
), "soft_budget should be loaded from budget table"
|
||||
assert (
|
||||
user_auth_obj.tpm_limit == 1000
|
||||
), "tpm_limit should be loaded from budget table"
|
||||
assert (
|
||||
user_auth_obj.rpm_limit == 100
|
||||
), "rpm_limit should be loaded from budget table"
|
||||
assert user_auth_obj.model_max_budget == {
|
||||
"gpt-4": 5.0
|
||||
}, "model_max_budget should be loaded from budget table"
|
||||
# Ensure original values from token table are not used if budget table values exist
|
||||
assert user_auth_obj.spend == 5.0 # Spend comes from the token table itself
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -319,13 +532,11 @@ async def test_virtual_key_max_budget_check(
|
|||
|
||||
user_obj = LiteLLM_UserTable(
|
||||
user_id="test-user",
|
||||
user_email="test@email.com",
|
||||
user_email="test@example.com",
|
||||
max_budget=None,
|
||||
)
|
||||
|
||||
proxy_logging_obj = ProxyLogging(
|
||||
user_api_key_cache=None,
|
||||
)
|
||||
proxy_logging_obj = MockProxyLogging()
|
||||
|
||||
# Track if budget alert was called
|
||||
alert_called = False
|
||||
|
@ -356,7 +567,6 @@ async def test_virtual_key_max_budget_check(
|
|||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Verify budget alert was triggered
|
||||
assert alert_called, "Budget alert should be triggered"
|
||||
|
||||
|
||||
|
@ -477,6 +687,89 @@ async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
|
|||
), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"token_spend, max_budget_from_table, expect_budget_error",
|
||||
[
|
||||
(5.0, 10.0, False), # Under budget
|
||||
(10.0, 10.0, True), # At budget limit
|
||||
(15.0, 10.0, True), # Over budget
|
||||
(5.0, None, False), # No budget set in table
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_virtual_key_max_budget_check_from_budget_table(
|
||||
token_spend, max_budget_from_table, expect_budget_error
|
||||
):
|
||||
"""
|
||||
Test if virtual key budget checks work when max_budget is derived
|
||||
from the joined LiteLLM_BudgetTable data.
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
# Setup test data - Simulate data structure after get_key_object fix
|
||||
valid_token = UserAPIKeyAuth(
|
||||
token="test-token-from-table",
|
||||
spend=token_spend,
|
||||
max_budget=max_budget_from_table, # This now reflects the budget from the table
|
||||
user_id="test-user-table",
|
||||
key_alias="test-key-table",
|
||||
# Simulate that litellm_budget_table was present during the join
|
||||
litellm_budget_table={
|
||||
"max_budget": max_budget_from_table,
|
||||
"soft_budget": None, # Assuming soft_budget is not the focus here
|
||||
# Add other necessary fields if the model requires them
|
||||
}
|
||||
if max_budget_from_table is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
user_obj = LiteLLM_UserTable(
|
||||
user_id="test-user-table",
|
||||
user_email="test-table@example.com",
|
||||
max_budget=None, # Ensure user-level budget doesn't interfere
|
||||
)
|
||||
|
||||
proxy_logging_obj = MockProxyLogging() # Use the mock class defined above
|
||||
|
||||
# Track if budget alert was called
|
||||
alert_called = False
|
||||
|
||||
async def mock_budget_alert(*args, **kwargs):
|
||||
nonlocal alert_called
|
||||
alert_called = True
|
||||
|
||||
proxy_logging_obj.budget_alerts = mock_budget_alert
|
||||
|
||||
try:
|
||||
await _virtual_key_max_budget_check(
|
||||
valid_token=valid_token,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
user_obj=user_obj,
|
||||
)
|
||||
if expect_budget_error:
|
||||
pytest.fail(
|
||||
f"Expected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}"
|
||||
)
|
||||
except litellm.BudgetExceededError as e:
|
||||
if not expect_budget_error:
|
||||
pytest.fail(
|
||||
f"Unexpected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}"
|
||||
)
|
||||
assert e.current_cost == token_spend
|
||||
assert e.max_budget == max_budget_from_table
|
||||
|
||||
await asyncio.sleep(0.1) # Allow time for alert task
|
||||
|
||||
# Verify budget alert was triggered only if there was a budget
|
||||
if max_budget_from_table is not None:
|
||||
assert alert_called, "Budget alert should be triggered when max_budget is set"
|
||||
else:
|
||||
assert (
|
||||
not alert_called
|
||||
), "Budget alert should not be triggered when max_budget is None"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_user_call_model():
|
||||
from litellm.proxy.auth.auth_checks import can_user_call_model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue