(Testing + Refactor) - Unit testing for team and virtual key budget checks (#7945)

* unit testing for test_virtual_key_max_budget_check

* refactor _team_max_budget_check

* is_model_allowed_by_pattern
Authored by Ishaan Jaff on 2025-01-23 16:58:16 -08:00, committed by GitHub
parent 8479d05d49
commit e0fb7eb4f7
3 changed files with 377 additions and 98 deletions


@@ -8,8 +8,8 @@ Run checks for:
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import asyncio
+import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@@ -55,7 +55,7 @@ db_cache_expiry = 5 # refresh every 5s
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value


-def common_checks( # noqa: PLR0915
+async def common_checks(
    request_body: dict,
    team_object: Optional[LiteLLM_TeamTable],
    user_object: Optional[LiteLLM_UserTable],
@@ -64,6 +64,8 @@ def common_checks( # noqa: PLR0915
    general_settings: dict,
    route: str,
    llm_router: Optional[Router],
+    proxy_logging_obj: ProxyLogging,
+    valid_token: Optional[UserAPIKeyAuth],
) -> bool:
    """
    Common checks across jwt + key-based auth.
@@ -80,52 +82,27 @@ def common_checks( # noqa: PLR0915
    10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
    """
    _model = request_body.get("model", None)
+
+    # 1. If team is blocked
    if team_object is not None and team_object.blocked is True:
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
        )
+
    # 2. If team can call model
-    if (
-        _model is not None
-        and team_object is not None
-        and team_object.models is not None
-        and len(team_object.models) > 0
-        and _model not in team_object.models
-    ):
-        # this means the team has access to all models on the proxy
-        if (
-            "all-proxy-models" in team_object.models
-            or "*" in team_object.models
-            or "openai/*" in team_object.models
-        ):
-            # this means the team has access to all models on the proxy
-            pass
-        # check if the team model is an access_group
-        elif (
-            model_in_access_group(
-                model=_model, team_models=team_object.models, llm_router=llm_router
-            )
-            is True
-        ):
-            pass
-        elif _model and "*" in _model:
-            pass
-        else:
-            raise Exception(
-                f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
-            )
+    _team_model_access_check(
+        team_object=team_object,
+        model=_model,
+        llm_router=llm_router,
+    )
+
    # 3. If team is in budget
-    if (
-        team_object is not None
-        and team_object.max_budget is not None
-        and team_object.spend is not None
-        and team_object.spend > team_object.max_budget
-    ):
-        raise litellm.BudgetExceededError(
-            current_cost=team_object.spend,
-            max_budget=team_object.max_budget,
-            message=f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}",
-        )
+    await _team_max_budget_check(
+        team_object=team_object,
+        proxy_logging_obj=proxy_logging_obj,
+        valid_token=valid_token,
+    )
+
    # 4. If user is in budget
    ## 4.1 check personal budget, if personal key
    if (
@@ -140,6 +117,7 @@ def common_checks( # noqa: PLR0915
                max_budget=user_budget,
                message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
            )
+
    ## 4.2 check team member budget, if team key
    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
    if end_user_object is not None and end_user_object.litellm_budget_table is not None:
@@ -150,6 +128,7 @@ def common_checks( # noqa: PLR0915
                max_budget=end_user_budget,
                message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
            )
+
    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    if (
        general_settings.get("enforce_user_param", None) is not None
@@ -981,3 +960,144 @@ async def _virtual_key_max_budget_check(
            current_cost=valid_token.spend,
            max_budget=valid_token.max_budget,
        )


async def _virtual_key_soft_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
):
    """
    Triggers a budget alert if the token is over its soft budget.
    """
    if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
        verbose_proxy_logger.debug(
            "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
            valid_token.token,
            valid_token.spend,
            valid_token.soft_budget,
        )
        call_info = CallInfo(
            token=valid_token.token,
            spend=valid_token.spend,
            max_budget=valid_token.max_budget,
            soft_budget=valid_token.soft_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
            user_email=None,
            key_alias=valid_token.key_alias,
        )
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="soft_budget",
                user_info=call_info,
            )
        )


async def _team_max_budget_check(
    team_object: Optional[LiteLLM_TeamTable],
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
):
    """
    Check if the team is over its max budget.

    Raises:
        BudgetExceededError if the team is over its max budget.
        Triggers a budget alert if the team is over its max budget.
    """
    if (
        team_object is not None
        and team_object.max_budget is not None
        and team_object.spend is not None
        and team_object.spend > team_object.max_budget
    ):
        if valid_token:
            call_info = CallInfo(
                token=valid_token.token,
                spend=team_object.spend,
                max_budget=team_object.max_budget,
                user_id=valid_token.user_id,
                team_id=valid_token.team_id,
                team_alias=valid_token.team_alias,
            )
            asyncio.create_task(
                proxy_logging_obj.budget_alerts(
                    type="team_budget",
                    user_info=call_info,
                )
            )

        raise litellm.BudgetExceededError(
            current_cost=team_object.spend,
            max_budget=team_object.max_budget,
            message=f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}",
        )


def _team_model_access_check(
    model: Optional[str],
    team_object: Optional[LiteLLM_TeamTable],
    llm_router: Optional[Router],
):
    """
    Access check for team models.

    Raises:
        Exception if the team is not allowed to call the `model`
    """
    if (
        model is not None
        and team_object is not None
        and team_object.models is not None
        and len(team_object.models) > 0
        and model not in team_object.models
    ):
        # this means the team has access to all models on the proxy
        if (
            "all-proxy-models" in team_object.models
            or "*" in team_object.models
            or "openai/*" in team_object.models
        ):
            # this means the team has access to all models on the proxy
            pass
        # check if the team model is an access_group
        elif (
            model_in_access_group(
                model=model, team_models=team_object.models, llm_router=llm_router
            )
            is True
        ):
            pass
        elif model and "*" in model:
            pass
        elif any(
            is_model_allowed_by_pattern(model=model, allowed_model_pattern=team_model)
            for team_model in team_object.models
        ):
            pass
        else:
            raise Exception(
                f"Team={team_object.team_id} not allowed to call model={model}. Allowed team models = {team_object.models}"
            )


def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
    """
    Check if a model matches an allowed pattern.
    Handles exact matches and wildcard patterns.

    Args:
        model (str): The model to check (e.g., "bedrock/anthropic.claude-3-5-sonnet-20240620")
        allowed_model_pattern (str): The allowed pattern (e.g., "bedrock/*", "*", "openai/*")

    Returns:
        bool: True if model matches the pattern, False otherwise
    """
    if "*" in allowed_model_pattern:
        pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
        return bool(re.match(pattern, model))

    return False
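
A quick way to see what the new wildcard matcher buys: the allowed pattern is converted to a regex anchored at both ends, so a prefix pattern only matches models that actually start with that prefix. The snippet below is a standalone sketch that mirrors the helper's logic outside the proxy (the name `matches` is ours, not litellm's); the expected results line up with the parametrized test cases further down.

import re


def matches(model: str, allowed_model_pattern: str) -> bool:
    # Mirror of is_model_allowed_by_pattern above: only wildcard patterns are
    # handled here; exact (non-wildcard) entries are handled upstream by plain
    # membership checks in _team_model_access_check.
    if "*" in allowed_model_pattern:
        pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
        return bool(re.match(pattern, model))
    return False


if __name__ == "__main__":
    assert matches("bedrock/anthropic.claude-3-5-sonnet-20240620", "bedrock/*")
    assert not matches("bedrockz/anthropic.claude-3-5-sonnet-20240620", "bedrock/*")
    assert matches("bedrock/claude-3-5-sonnet-20240620", "bedrock/claude-*")
    assert not matches("bedrock/claude-3-6-sonnet-20240620", "bedrock/claude-3-5-*")
    assert not matches("gpt-4", "gpt-3.5-turbo")  # no "*" in the pattern -> False
    print("pattern checks behave as expected")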


@@ -25,6 +25,7 @@ from litellm.proxy.auth.auth_checks import (
    _cache_key_object,
    _handle_failed_db_connection_for_get_key_object,
    _virtual_key_max_budget_check,
+    _virtual_key_soft_budget_check,
    allowed_routes_check,
    can_key_call_model,
    common_checks,
@@ -613,7 +614,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
                parent_otel_span=parent_otel_span,
            )
            # run through common checks
-            _ = common_checks(
+            _ = await common_checks(
                request_body=request_data,
                team_object=team_object,
                user_object=user_object,
@@ -622,6 +623,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
                global_proxy_spend=global_proxy_spend,
                route=route,
                llm_router=llm_router,
+                proxy_logging_obj=proxy_logging_obj,
+                valid_token=None,
            )

            # return UserAPIKeyAuth object
@@ -1099,30 +1102,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
                user_obj=user_obj,
            )

-            if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
-                verbose_proxy_logger.debug(
-                    "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
-                    valid_token.token,
-                    valid_token.spend,
-                    valid_token.soft_budget,
-                )
-                call_info = CallInfo(
-                    token=valid_token.token,
-                    spend=valid_token.spend,
-                    max_budget=valid_token.max_budget,
-                    soft_budget=valid_token.soft_budget,
-                    user_id=valid_token.user_id,
-                    team_id=valid_token.team_id,
-                    team_alias=valid_token.team_alias,
-                    user_email=None,
-                    key_alias=valid_token.key_alias,
-                )
-                asyncio.create_task(
-                    proxy_logging_obj.budget_alerts(
-                        type="soft_budget",
-                        user_info=call_info,
-                    )
-                )
+            # Check 5. Soft Budget Check
+            await _virtual_key_soft_budget_check(
+                valid_token=valid_token,
+                proxy_logging_obj=proxy_logging_obj,
+            )

            # Check 5. Token Model Spend is under Model budget
            max_budget_per_model = valid_token.model_max_budget
@@ -1141,35 +1125,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
                    user_api_key_dict=valid_token,
                    model=current_model,
                )

-            # Check 6. Team spend is under Team budget
-            if (
-                hasattr(valid_token, "team_spend")
-                and valid_token.team_spend is not None
-                and hasattr(valid_token, "team_max_budget")
-                and valid_token.team_max_budget is not None
-            ):
-                call_info = CallInfo(
-                    token=valid_token.token,
-                    spend=valid_token.team_spend,
-                    max_budget=valid_token.team_max_budget,
-                    user_id=valid_token.user_id,
-                    team_id=valid_token.team_id,
-                    team_alias=valid_token.team_alias,
-                )
-                asyncio.create_task(
-                    proxy_logging_obj.budget_alerts(
-                        type="team_budget",
-                        user_info=call_info,
-                    )
-                )
-                if valid_token.team_spend >= valid_token.team_max_budget:
-                    raise litellm.BudgetExceededError(
-                        current_cost=valid_token.team_spend,
-                        max_budget=valid_token.team_max_budget,
-                    )
-            # Check 8: Additional Common Checks across jwt + key auth
+            # Check 6: Additional Common Checks across jwt + key auth
            if valid_token.team_id is not None:
                _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
                    team_id=valid_token.team_id,
@@ -1184,7 +1141,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
            else:
                _team_obj = None

-            # Check 9: Check if key is a service account key
+            # Check 7: Check if key is a service account key
            await service_account_checks(
                valid_token=valid_token,
                request_data=request_data,
@@ -1228,7 +1185,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
                        user_info=call_info,
                    )
                )
-            _ = common_checks(
+            _ = await common_checks(
                request_body=request_data,
                team_object=_team_obj,
                user_object=user_obj,
@@ -1237,6 +1194,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
                global_proxy_spend=global_proxy_spend,
                route=route,
                llm_router=llm_router,
+                proxy_logging_obj=proxy_logging_obj,
+                valid_token=valid_token,
            )
            # Token passed all checks
            if valid_token is None:
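
Both extracted helpers keep the same alerting pattern the inline code used: schedule the budget alert as a fire-and-forget asyncio task, then raise (or return) without awaiting it, so alert delivery never blocks the authentication path. Below is a minimal standalone sketch of that pattern, with hypothetical names (send_alert, BudgetExceeded) standing in for ProxyLogging.budget_alerts and litellm.BudgetExceededError.

import asyncio


class BudgetExceeded(Exception):
    """Stand-in for litellm.BudgetExceededError in this sketch."""


async def send_alert(kind: str, info: dict) -> None:
    # Stand-in for ProxyLogging.budget_alerts(); pretend this does slow I/O.
    await asyncio.sleep(0.05)
    print(f"alert sent: {kind} -> {info}")


async def max_budget_check(spend: float, max_budget: float) -> None:
    # Same shape as _team_max_budget_check above: kick off the alert in the
    # background, then fail the request immediately without awaiting it.
    if spend > max_budget:
        asyncio.create_task(
            send_alert("team_budget", {"spend": spend, "max_budget": max_budget})
        )
        raise BudgetExceeded(f"spend {spend} > max_budget {max_budget}")


async def main() -> None:
    try:
        await max_budget_check(spend=12.0, max_budget=10.0)
    except BudgetExceeded as e:
        print(f"request rejected: {e}")
    # Give the background alert task a moment to finish, mirroring the
    # asyncio.sleep() calls the new unit tests use before asserting on alerts.
    await asyncio.sleep(0.1)


if __name__ == "__main__":
    asyncio.run(main())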


@@ -19,8 +19,19 @@ from litellm.proxy.auth.auth_checks import (
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import get_end_user_object
from litellm.caching.caching import DualCache
-from litellm.proxy._types import LiteLLM_EndUserTable, LiteLLM_BudgetTable
+from litellm.proxy._types import (
+    LiteLLM_EndUserTable,
+    LiteLLM_BudgetTable,
+    LiteLLM_UserTable,
+    LiteLLM_TeamTable,
+)
from litellm.proxy.utils import PrismaClient
+from litellm.proxy.auth.auth_checks import (
+    _team_model_access_check,
+    _virtual_key_soft_budget_check,
+)
+from litellm.proxy.utils import ProxyLogging
+from litellm.proxy.utils import CallInfo


@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
@@ -229,3 +240,192 @@ async def test_is_valid_fallback_model():
        pytest.fail("Expected is_valid_fallback_model to fail")
    except Exception as e:
        assert "Invalid" in str(e)


@pytest.mark.parametrize(
    "token_spend, max_budget, expect_budget_error",
    [
        (5.0, 10.0, False),  # Under budget
        (10.0, 10.0, True),  # At budget limit
        (15.0, 10.0, True),  # Over budget
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_max_budget_check(
    token_spend, max_budget, expect_budget_error
):
    """
    Test if virtual key budget checks work as expected:
    1. Triggers budget alert for all cases
    2. Raises BudgetExceededError when spend >= max_budget
    """
    from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check
    from litellm.proxy.utils import ProxyLogging

    # Setup test data
    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=token_spend,
        max_budget=max_budget,
        user_id="test-user",
        key_alias="test-key",
    )
    user_obj = LiteLLM_UserTable(
        user_id="test-user",
        user_email="test@email.com",
        max_budget=None,
    )
    proxy_logging_obj = ProxyLogging(
        user_api_key_cache=None,
    )

    # Track if budget alert was called
    alert_called = False

    async def mock_budget_alert(*args, **kwargs):
        nonlocal alert_called
        alert_called = True

    proxy_logging_obj.budget_alerts = mock_budget_alert

    try:
        await _virtual_key_max_budget_check(
            valid_token=valid_token,
            proxy_logging_obj=proxy_logging_obj,
            user_obj=user_obj,
        )
        if expect_budget_error:
            pytest.fail(
                f"Expected BudgetExceededError for spend={token_spend}, max_budget={max_budget}"
            )
    except litellm.BudgetExceededError as e:
        if not expect_budget_error:
            pytest.fail(
                f"Unexpected BudgetExceededError for spend={token_spend}, max_budget={max_budget}"
            )
        assert e.current_cost == token_spend
        assert e.max_budget == max_budget

    await asyncio.sleep(1)

    # Verify budget alert was triggered
    assert alert_called, "Budget alert should be triggered"


@pytest.mark.parametrize(
    "model, team_models, expect_to_work",
    [
        ("gpt-4", ["gpt-4"], True),  # exact match
        ("gpt-4", ["all-proxy-models"], True),  # all-proxy-models access
        ("gpt-4", ["*"], True),  # wildcard access
        ("gpt-4", ["openai/*"], True),  # openai wildcard access
        (
            "bedrock/anthropic.claude-3-5-sonnet-20240620",
            ["bedrock/*"],
            True,
        ),  # wildcard access
        (
            "bedrockz/anthropic.claude-3-5-sonnet-20240620",
            ["bedrock/*"],
            False,
        ),  # non-match wildcard access
        ("bedrock/very_new_model", ["bedrock/*"], True),  # bedrock wildcard access
        (
            "bedrock/claude-3-5-sonnet-20240620",
            ["bedrock/claude-*"],
            True,
        ),  # match on pattern
        (
            "bedrock/claude-3-6-sonnet-20240620",
            ["bedrock/claude-3-5-*"],
            False,
        ),  # don't match on pattern
        ("openai/gpt-4o", ["openai/*"], True),  # openai wildcard access
        ("gpt-4", ["gpt-3.5-turbo"], False),  # model not in allowed list
        ("claude-3", [], True),  # empty model list (allows all)
    ],
)
@pytest.mark.asyncio
async def test_team_model_access_check(model, team_models, expect_to_work):
    """
    Test cases for _team_model_access_check:
    1. Exact model match
    2. all-proxy-models access
    3. Wildcard (*) access
    4. OpenAI wildcard access
    5. Model not in allowed list
    6. Empty model list
    7. None model list
    """
    team_object = LiteLLM_TeamTable(
        team_id="test-team",
        models=team_models,
    )

    try:
        _team_model_access_check(
            model=model,
            team_object=team_object,
            llm_router=None,
        )
        if not expect_to_work:
            pytest.fail(
                f"Expected model access check to fail for model={model}, team_models={team_models}"
            )
    except Exception as e:
        if expect_to_work:
            pytest.fail(
                f"Expected model access check to work for model={model}, team_models={team_models}. Got error: {str(e)}"
            )


@pytest.mark.parametrize(
    "spend, soft_budget, expect_alert",
    [
        (100, 50, True),  # Over soft budget
        (50, 50, True),  # At soft budget
        (25, 50, False),  # Under soft budget
        (100, None, False),  # No soft budget set
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
    """
    Test cases for _virtual_key_soft_budget_check:
    1. Spend over soft budget
    2. Spend at soft budget
    3. Spend under soft budget
    4. No soft budget set
    """
    alert_triggered = False

    class MockProxyLogging:
        async def budget_alerts(self, type, user_info):
            nonlocal alert_triggered
            alert_triggered = True
            assert type == "soft_budget"
            assert isinstance(user_info, CallInfo)

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=spend,
        soft_budget=soft_budget,
        user_id="test-user",
        team_id="test-team",
        key_alias="test-key",
    )
    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_soft_budget_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
    )
    await asyncio.sleep(0.1)  # Allow time for the alert task to complete

    assert (
        alert_triggered == expect_alert
    ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"