mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge 3f8b827f79
into b82af5b826
This commit is contained in:
commit
7082f4f95e
5 changed files with 355 additions and 26 deletions
|
@ -1007,7 +1007,24 @@ async def get_key_object(
|
||||||
code=status.HTTP_401_UNAUTHORIZED,
|
code=status.HTTP_401_UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
|
token_data = _valid_token.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
# Manually map the budget from the joined budget table if it exists
|
||||||
|
# This ensures the max_budget from LiteLLM_BudgetTable takes precedence if set
|
||||||
|
if token_data.get("litellm_budget_table_max_budget") is not None:
|
||||||
|
token_data["max_budget"] = token_data.pop("litellm_budget_table_max_budget")
|
||||||
|
if token_data.get("litellm_budget_table_soft_budget") is not None:
|
||||||
|
token_data["soft_budget"] = token_data.pop("litellm_budget_table_soft_budget")
|
||||||
|
# Only override if budget table value is explicitly set (not None) and > 0
|
||||||
|
budget_tpm_limit = token_data.pop("litellm_budget_table_tpm_limit", None)
|
||||||
|
if budget_tpm_limit is not None and budget_tpm_limit > 0:
|
||||||
|
token_data["tpm_limit"] = budget_tpm_limit
|
||||||
|
budget_rpm_limit = token_data.pop("litellm_budget_table_rpm_limit", None)
|
||||||
|
if budget_rpm_limit is not None and budget_rpm_limit > 0:
|
||||||
|
token_data["rpm_limit"] = budget_rpm_limit
|
||||||
|
|
||||||
|
|
||||||
|
_response = UserAPIKeyAuth(**token_data)
|
||||||
|
|
||||||
# save the key object to cache
|
# save the key object to cache
|
||||||
await _cache_key_object(
|
await _cache_key_object(
|
||||||
|
|
|
@ -21,29 +21,37 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
|
verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
|
||||||
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
|
|
||||||
user_row = await cache.async_get_cache(
|
# Use the budget information directly from the validated user_api_key_dict
|
||||||
cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
|
max_budget = user_api_key_dict.max_budget
|
||||||
)
|
curr_spend = user_api_key_dict.spend
|
||||||
if user_row is None: # value not yet cached
|
|
||||||
return
|
|
||||||
max_budget = user_row["max_budget"]
|
|
||||||
curr_spend = user_row["spend"]
|
|
||||||
|
|
||||||
if max_budget is None:
|
if max_budget is None:
|
||||||
|
# No budget limit set for this key/user/team
|
||||||
return
|
return
|
||||||
|
|
||||||
if curr_spend is None:
|
if curr_spend is None:
|
||||||
return
|
# If spend tracking hasn't started, assume 0
|
||||||
|
curr_spend = 0.0
|
||||||
|
|
||||||
# CHECK IF REQUEST ALLOWED
|
# CHECK IF REQUEST ALLOWED
|
||||||
if curr_spend >= max_budget:
|
if curr_spend >= max_budget:
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
f"Budget Limit Reached for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}"
|
||||||
|
)
|
||||||
raise HTTPException(status_code=429, detail="Max budget limit reached.")
|
raise HTTPException(status_code=429, detail="Max budget limit reached.")
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Budget Check Passed for {user_api_key_dict.user_id or user_api_key_dict.team_id or user_api_key_dict.api_key}. Current Spend: {curr_spend}, Max Budget: {max_budget}"
|
||||||
|
)
|
||||||
|
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
|
# Re-raise HTTPException to ensure FastAPI handles it correctly
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.exception(
|
verbose_logger.exception(
|
||||||
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
|
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||||
str(e)
|
str(e)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
pass
|
||||||
|
|
|
@ -195,10 +195,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
"global_max_parallel_requests", None
|
"global_max_parallel_requests", None
|
||||||
)
|
)
|
||||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||||
if tpm_limit is None:
|
if tpm_limit is None or tpm_limit == 0: # Treat 0 as no limit
|
||||||
tpm_limit = sys.maxsize
|
tpm_limit = sys.maxsize
|
||||||
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
||||||
if rpm_limit is None:
|
if rpm_limit is None or rpm_limit == 0: # Treat 0 as no limit
|
||||||
rpm_limit = sys.maxsize
|
rpm_limit = sys.maxsize
|
||||||
|
|
||||||
values_to_update_in_cache: List[
|
values_to_update_in_cache: List[
|
||||||
|
@ -310,9 +310,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
if _model is not None:
|
if _model is not None:
|
||||||
if _tpm_limit_for_key_model:
|
if _tpm_limit_for_key_model:
|
||||||
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
|
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
|
||||||
|
if tpm_limit_for_model == 0: # Treat 0 as no limit
|
||||||
|
tpm_limit_for_model = sys.maxsize
|
||||||
|
|
||||||
if _rpm_limit_for_key_model:
|
if _rpm_limit_for_key_model:
|
||||||
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
|
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
|
||||||
|
if rpm_limit_for_model == 0: # Treat 0 as no limit
|
||||||
|
rpm_limit_for_model = sys.maxsize
|
||||||
|
|
||||||
new_val = await self.check_key_in_limits(
|
new_val = await self.check_key_in_limits(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
@ -350,9 +354,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
user_tpm_limit = user_api_key_dict.user_tpm_limit
|
user_tpm_limit = user_api_key_dict.user_tpm_limit
|
||||||
user_rpm_limit = user_api_key_dict.user_rpm_limit
|
user_rpm_limit = user_api_key_dict.user_rpm_limit
|
||||||
if user_tpm_limit is None:
|
if user_tpm_limit is None or user_tpm_limit == 0: # Treat 0 as no limit
|
||||||
user_tpm_limit = sys.maxsize
|
user_tpm_limit = sys.maxsize
|
||||||
if user_rpm_limit is None:
|
if user_rpm_limit is None or user_rpm_limit == 0: # Treat 0 as no limit
|
||||||
user_rpm_limit = sys.maxsize
|
user_rpm_limit = sys.maxsize
|
||||||
|
|
||||||
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||||
|
@ -378,9 +382,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
team_tpm_limit = user_api_key_dict.team_tpm_limit
|
team_tpm_limit = user_api_key_dict.team_tpm_limit
|
||||||
team_rpm_limit = user_api_key_dict.team_rpm_limit
|
team_rpm_limit = user_api_key_dict.team_rpm_limit
|
||||||
|
|
||||||
if team_tpm_limit is None:
|
if team_tpm_limit is None or team_tpm_limit == 0: # Treat 0 as no limit
|
||||||
team_tpm_limit = sys.maxsize
|
team_tpm_limit = sys.maxsize
|
||||||
if team_rpm_limit is None:
|
if team_rpm_limit is None or team_rpm_limit == 0: # Treat 0 as no limit
|
||||||
team_rpm_limit = sys.maxsize
|
team_rpm_limit = sys.maxsize
|
||||||
|
|
||||||
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
|
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
|
||||||
|
@ -409,9 +413,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
|
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
|
||||||
)
|
)
|
||||||
|
|
||||||
if end_user_tpm_limit is None:
|
if (
|
||||||
|
end_user_tpm_limit is None or end_user_tpm_limit == 0
|
||||||
|
): # Treat 0 as no limit
|
||||||
end_user_tpm_limit = sys.maxsize
|
end_user_tpm_limit = sys.maxsize
|
||||||
if end_user_rpm_limit is None:
|
if (
|
||||||
|
end_user_rpm_limit is None or end_user_rpm_limit == 0
|
||||||
|
): # Treat 0 as no limit
|
||||||
end_user_rpm_limit = sys.maxsize
|
end_user_rpm_limit = sys.maxsize
|
||||||
|
|
||||||
# now do the same tpm/rpm checks
|
# now do the same tpm/rpm checks
|
||||||
|
|
|
@ -491,12 +491,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Make a copy of the data *before* adding the proxy_server_request key
|
||||||
|
original_data_copy = copy.deepcopy(data) # Use deepcopy for nested structures
|
||||||
|
|
||||||
# Include original request and headers in the data
|
# Include original request and headers in the data
|
||||||
data["proxy_server_request"] = {
|
data["proxy_server_request"] = {
|
||||||
"url": str(request.url),
|
"url": str(request.url),
|
||||||
"method": request.method,
|
"method": request.method,
|
||||||
"headers": _headers,
|
"headers": _headers,
|
||||||
"body": copy.copy(data), # use copy instead of deepcopy
|
"body": original_data_copy, # Use the deep copy without the circular reference
|
||||||
}
|
}
|
||||||
|
|
||||||
## Dynamic api version (Azure OpenAI endpoints) ##
|
## Dynamic api version (Azure OpenAI endpoints) ##
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
import sys, os, asyncio, time, random, uuid
|
import sys, os, asyncio, time, random, uuid
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -21,6 +22,7 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_BudgetTable,
|
LiteLLM_BudgetTable,
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
LiteLLM_TeamTable,
|
LiteLLM_TeamTable,
|
||||||
|
LiteLLM_VerificationTokenView,
|
||||||
)
|
)
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
|
@ -29,6 +31,8 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.utils import ProxyLogging
|
from litellm.proxy.utils import ProxyLogging
|
||||||
from litellm.proxy.utils import CallInfo
|
from litellm.proxy.utils import CallInfo
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
|
@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
|
||||||
|
@ -255,7 +259,216 @@ async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_w
|
||||||
llm_router=router,
|
llm_router=router,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(e)
|
|
||||||
|
# Mock ProxyLogging for budget alert testing
|
||||||
|
class MockProxyLogging:
|
||||||
|
def __init__(self):
|
||||||
|
self.alert_triggered = False
|
||||||
|
self.alert_type = None
|
||||||
|
self.user_info = None
|
||||||
|
|
||||||
|
async def budget_alerts(self, type, user_info):
|
||||||
|
self.alert_triggered = True
|
||||||
|
self.alert_type = type
|
||||||
|
self.user_info = user_info
|
||||||
|
|
||||||
|
# Add dummy methods for other required ProxyLogging methods if needed
|
||||||
|
async def pre_call_hook(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def post_call_failure_hook(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def post_call_success_hook(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_post_call_streaming_hook(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def async_post_call_streaming_iterator_hook(self, response, *args, **kwargs):
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _init_response_taking_too_long_task(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def update_request_status(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def failed_tracking_alert(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def alerting_handler(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def failure_handler(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"token_spend, max_budget, expect_error",
|
||||||
|
[
|
||||||
|
(5.0, 10.0, False), # Under budget
|
||||||
|
(9.99, 10.0, False), # Just under budget
|
||||||
|
(0.0, 0.0, True), # At zero budget
|
||||||
|
(0.0, None, False), # No budget set
|
||||||
|
(10.0, 10.0, True), # At budget limit
|
||||||
|
(15.0, 10.0, True), # Over budget
|
||||||
|
(None, 10.0, False), # Spend not tracked yet
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_budget_limiter_hook(token_spend, max_budget, expect_error):
|
||||||
|
"""
|
||||||
|
Test the _PROXY_MaxBudgetLimiter pre-call hook directly.
|
||||||
|
This test verifies the fix applied in the hook itself.
|
||||||
|
"""
|
||||||
|
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||||
|
from litellm.caching.caching import DualCache
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
limiter = _PROXY_MaxBudgetLimiter()
|
||||||
|
mock_cache = (
|
||||||
|
DualCache()
|
||||||
|
) # The hook expects a cache object, even if not used in the updated logic
|
||||||
|
|
||||||
|
# Ensure spend is a float, defaulting to 0.0 if None
|
||||||
|
actual_spend = token_spend if token_spend is not None else 0.0
|
||||||
|
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
token="test-token-hook",
|
||||||
|
spend=actual_spend,
|
||||||
|
max_budget=max_budget,
|
||||||
|
user_id="test-user-hook",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await limiter.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=mock_cache,
|
||||||
|
data=mock_data,
|
||||||
|
call_type="completion",
|
||||||
|
)
|
||||||
|
if expect_error:
|
||||||
|
pytest.fail(
|
||||||
|
f"Expected HTTPException for spend={token_spend}, max_budget={max_budget}"
|
||||||
|
)
|
||||||
|
except HTTPException as e:
|
||||||
|
if not expect_error:
|
||||||
|
pytest.fail(
|
||||||
|
f"Unexpected HTTPException for spend={token_spend}, max_budget={max_budget}: {e.detail}"
|
||||||
|
)
|
||||||
|
assert e.status_code == 429
|
||||||
|
assert "Max budget limit reached" in e.detail
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Unexpected exception type {type(e).__name__} raised: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_key_object_loads_budget_table_limits():
|
||||||
|
"""
|
||||||
|
Test if get_key_object correctly loads max_budget, tpm_limit, and rpm_limit
|
||||||
|
from the joined LiteLLM_BudgetTable when a budget_id is present on the key.
|
||||||
|
"""
|
||||||
|
from litellm.proxy.auth.auth_checks import get_key_object
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
from litellm.caching.caching import DualCache
|
||||||
|
|
||||||
|
# Mock Prisma response simulating the joined view
|
||||||
|
mock_db_response_dict = {
|
||||||
|
"token": "hashed_test_token_budget_table",
|
||||||
|
"key_name": "sk-...test",
|
||||||
|
"key_alias": "test-budget-key",
|
||||||
|
"spend": 5.0,
|
||||||
|
"max_budget": None, # Budget on token table itself is None
|
||||||
|
"expires": None,
|
||||||
|
"models": [],
|
||||||
|
"aliases": {},
|
||||||
|
"config": {},
|
||||||
|
"user_id": "test-user-budget",
|
||||||
|
"team_id": None,
|
||||||
|
"max_parallel_requests": None,
|
||||||
|
"metadata": {},
|
||||||
|
"tpm_limit": None, # Limit on token table itself is None
|
||||||
|
"rpm_limit": None, # Limit on token table itself is None
|
||||||
|
"budget_duration": None,
|
||||||
|
"budget_reset_at": None,
|
||||||
|
"allowed_cache_controls": [],
|
||||||
|
"permissions": {},
|
||||||
|
"model_spend": {},
|
||||||
|
"model_max_budget": {},
|
||||||
|
"soft_budget_cooldown": False,
|
||||||
|
"blocked": False,
|
||||||
|
"org_id": None,
|
||||||
|
"created_at": datetime.now(),
|
||||||
|
"updated_at": datetime.now(),
|
||||||
|
"created_by": None,
|
||||||
|
"updated_by": None,
|
||||||
|
"team_spend": None,
|
||||||
|
"team_alias": None,
|
||||||
|
"team_tpm_limit": None,
|
||||||
|
"team_rpm_limit": None,
|
||||||
|
"team_max_budget": None,
|
||||||
|
"team_models": [],
|
||||||
|
"team_blocked": False,
|
||||||
|
"team_model_aliases": None,
|
||||||
|
"team_member_spend": None,
|
||||||
|
"team_member": None,
|
||||||
|
"team_metadata": None,
|
||||||
|
"budget_id": "budget_123", # Link to budget table
|
||||||
|
# Values coming from the joined LiteLLM_BudgetTable
|
||||||
|
"litellm_budget_table_max_budget": 20.0,
|
||||||
|
"litellm_budget_table_soft_budget": 15.0,
|
||||||
|
"litellm_budget_table_tpm_limit": 1000,
|
||||||
|
"litellm_budget_table_rpm_limit": 100,
|
||||||
|
"litellm_budget_table_model_max_budget": {"gpt-4": 5.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock PrismaClient and its methods
|
||||||
|
mock_prisma_client = MagicMock(spec=PrismaClient)
|
||||||
|
|
||||||
|
# Create a mock object that mimics the structure returned by prisma.get_data for the raw query
|
||||||
|
mock_db_result = MagicMock()
|
||||||
|
for key, value in mock_db_response_dict.items():
|
||||||
|
setattr(mock_db_result, key, value)
|
||||||
|
|
||||||
|
# Add a model_dump method to the mock object
|
||||||
|
mock_db_result.model_dump = MagicMock(return_value=mock_db_response_dict)
|
||||||
|
|
||||||
|
# Mock the get_data method to return our simulated DB response object
|
||||||
|
mock_prisma_client.get_data = AsyncMock(return_value=mock_db_result)
|
||||||
|
|
||||||
|
mock_cache = DualCache()
|
||||||
|
mock_proxy_logging = MockProxyLogging() # Use the mock defined earlier
|
||||||
|
|
||||||
|
# Call get_key_object
|
||||||
|
user_auth_obj = await get_key_object(
|
||||||
|
hashed_token="hashed_test_token_budget_table",
|
||||||
|
prisma_client=mock_prisma_client,
|
||||||
|
user_api_key_cache=mock_cache,
|
||||||
|
proxy_logging_obj=mock_proxy_logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert (
|
||||||
|
user_auth_obj.max_budget == 20.0
|
||||||
|
), "max_budget should be loaded from budget table"
|
||||||
|
assert (
|
||||||
|
user_auth_obj.soft_budget == 15.0
|
||||||
|
), "soft_budget should be loaded from budget table"
|
||||||
|
assert (
|
||||||
|
user_auth_obj.tpm_limit == 1000
|
||||||
|
), "tpm_limit should be loaded from budget table"
|
||||||
|
assert (
|
||||||
|
user_auth_obj.rpm_limit == 100
|
||||||
|
), "rpm_limit should be loaded from budget table"
|
||||||
|
assert user_auth_obj.model_max_budget == {
|
||||||
|
"gpt-4": 5.0
|
||||||
|
}, "model_max_budget should be loaded from budget table"
|
||||||
|
# Ensure original values from token table are not used if budget table values exist
|
||||||
|
assert user_auth_obj.spend == 5.0 # Spend comes from the token table itself
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -319,13 +532,11 @@ async def test_virtual_key_max_budget_check(
|
||||||
|
|
||||||
user_obj = LiteLLM_UserTable(
|
user_obj = LiteLLM_UserTable(
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
user_email="test@email.com",
|
user_email="test@example.com",
|
||||||
max_budget=None,
|
max_budget=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
proxy_logging_obj = ProxyLogging(
|
proxy_logging_obj = MockProxyLogging()
|
||||||
user_api_key_cache=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Track if budget alert was called
|
# Track if budget alert was called
|
||||||
alert_called = False
|
alert_called = False
|
||||||
|
@ -356,7 +567,6 @@ async def test_virtual_key_max_budget_check(
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
# Verify budget alert was triggered
|
|
||||||
assert alert_called, "Budget alert should be triggered"
|
assert alert_called, "Budget alert should be triggered"
|
||||||
|
|
||||||
|
|
||||||
|
@ -477,6 +687,89 @@ async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
|
||||||
), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"
|
), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"token_spend, max_budget_from_table, expect_budget_error",
|
||||||
|
[
|
||||||
|
(5.0, 10.0, False), # Under budget
|
||||||
|
(10.0, 10.0, True), # At budget limit
|
||||||
|
(15.0, 10.0, True), # Over budget
|
||||||
|
(5.0, None, False), # No budget set in table
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_virtual_key_max_budget_check_from_budget_table(
|
||||||
|
token_spend, max_budget_from_table, expect_budget_error
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test if virtual key budget checks work when max_budget is derived
|
||||||
|
from the joined LiteLLM_BudgetTable data.
|
||||||
|
"""
|
||||||
|
from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
|
||||||
|
# Setup test data - Simulate data structure after get_key_object fix
|
||||||
|
valid_token = UserAPIKeyAuth(
|
||||||
|
token="test-token-from-table",
|
||||||
|
spend=token_spend,
|
||||||
|
max_budget=max_budget_from_table, # This now reflects the budget from the table
|
||||||
|
user_id="test-user-table",
|
||||||
|
key_alias="test-key-table",
|
||||||
|
# Simulate that litellm_budget_table was present during the join
|
||||||
|
litellm_budget_table={
|
||||||
|
"max_budget": max_budget_from_table,
|
||||||
|
"soft_budget": None, # Assuming soft_budget is not the focus here
|
||||||
|
# Add other necessary fields if the model requires them
|
||||||
|
}
|
||||||
|
if max_budget_from_table is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_obj = LiteLLM_UserTable(
|
||||||
|
user_id="test-user-table",
|
||||||
|
user_email="test-table@example.com",
|
||||||
|
max_budget=None, # Ensure user-level budget doesn't interfere
|
||||||
|
)
|
||||||
|
|
||||||
|
proxy_logging_obj = MockProxyLogging() # Use the mock class defined above
|
||||||
|
|
||||||
|
# Track if budget alert was called
|
||||||
|
alert_called = False
|
||||||
|
|
||||||
|
async def mock_budget_alert(*args, **kwargs):
|
||||||
|
nonlocal alert_called
|
||||||
|
alert_called = True
|
||||||
|
|
||||||
|
proxy_logging_obj.budget_alerts = mock_budget_alert
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _virtual_key_max_budget_check(
|
||||||
|
valid_token=valid_token,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
user_obj=user_obj,
|
||||||
|
)
|
||||||
|
if expect_budget_error:
|
||||||
|
pytest.fail(
|
||||||
|
f"Expected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}"
|
||||||
|
)
|
||||||
|
except litellm.BudgetExceededError as e:
|
||||||
|
if not expect_budget_error:
|
||||||
|
pytest.fail(
|
||||||
|
f"Unexpected BudgetExceededError for spend={token_spend}, max_budget_from_table={max_budget_from_table}"
|
||||||
|
)
|
||||||
|
assert e.current_cost == token_spend
|
||||||
|
assert e.max_budget == max_budget_from_table
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1) # Allow time for alert task
|
||||||
|
|
||||||
|
# Verify budget alert was triggered only if there was a budget
|
||||||
|
if max_budget_from_table is not None:
|
||||||
|
assert alert_called, "Budget alert should be triggered when max_budget is set"
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
not alert_called
|
||||||
|
), "Budget alert should not be triggered when max_budget is None"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_can_user_call_model():
|
async def test_can_user_call_model():
|
||||||
from litellm.proxy.auth.auth_checks import can_user_call_model
|
from litellm.proxy.auth.auth_checks import can_user_call_model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue