mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
(Testing + Refactor) - Unit testing for team and virtual key budget checks (#7945)
* unit testing for test_virtual_key_max_budget_check * refactor _team_max_budget_check * is_model_allowed_by_pattern
This commit is contained in:
parent
8479d05d49
commit
e0fb7eb4f7
3 changed files with 377 additions and 98 deletions
|
@ -8,8 +8,8 @@ Run checks for:
|
||||||
2. If user is in budget
|
2. If user is in budget
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||||
|
@ -55,7 +55,7 @@ db_cache_expiry = 5 # refresh every 5s
|
||||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||||
|
|
||||||
|
|
||||||
def common_checks( # noqa: PLR0915
|
async def common_checks(
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
team_object: Optional[LiteLLM_TeamTable],
|
team_object: Optional[LiteLLM_TeamTable],
|
||||||
user_object: Optional[LiteLLM_UserTable],
|
user_object: Optional[LiteLLM_UserTable],
|
||||||
|
@ -64,6 +64,8 @@ def common_checks( # noqa: PLR0915
|
||||||
general_settings: dict,
|
general_settings: dict,
|
||||||
route: str,
|
route: str,
|
||||||
llm_router: Optional[Router],
|
llm_router: Optional[Router],
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
valid_token: Optional[UserAPIKeyAuth],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Common checks across jwt + key-based auth.
|
Common checks across jwt + key-based auth.
|
||||||
|
@ -80,52 +82,27 @@ def common_checks( # noqa: PLR0915
|
||||||
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
||||||
"""
|
"""
|
||||||
_model = request_body.get("model", None)
|
_model = request_body.get("model", None)
|
||||||
|
|
||||||
|
# 1. If team is blocked
|
||||||
if team_object is not None and team_object.blocked is True:
|
if team_object is not None and team_object.blocked is True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. If team can call model
|
# 2. If team can call model
|
||||||
if (
|
_team_model_access_check(
|
||||||
_model is not None
|
team_object=team_object,
|
||||||
and team_object is not None
|
model=_model,
|
||||||
and team_object.models is not None
|
llm_router=llm_router,
|
||||||
and len(team_object.models) > 0
|
)
|
||||||
and _model not in team_object.models
|
|
||||||
):
|
|
||||||
# this means the team has access to all models on the proxy
|
|
||||||
if (
|
|
||||||
"all-proxy-models" in team_object.models
|
|
||||||
or "*" in team_object.models
|
|
||||||
or "openai/*" in team_object.models
|
|
||||||
):
|
|
||||||
# this means the team has access to all models on the proxy
|
|
||||||
pass
|
|
||||||
# check if the team model is an access_group
|
|
||||||
elif (
|
|
||||||
model_in_access_group(
|
|
||||||
model=_model, team_models=team_object.models, llm_router=llm_router
|
|
||||||
)
|
|
||||||
is True
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
elif _model and "*" in _model:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
|
|
||||||
)
|
|
||||||
# 3. If team is in budget
|
# 3. If team is in budget
|
||||||
if (
|
await _team_max_budget_check(
|
||||||
team_object is not None
|
team_object=team_object,
|
||||||
and team_object.max_budget is not None
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
and team_object.spend is not None
|
valid_token=valid_token,
|
||||||
and team_object.spend > team_object.max_budget
|
)
|
||||||
):
|
|
||||||
raise litellm.BudgetExceededError(
|
|
||||||
current_cost=team_object.spend,
|
|
||||||
max_budget=team_object.max_budget,
|
|
||||||
message=f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}",
|
|
||||||
)
|
|
||||||
# 4. If user is in budget
|
# 4. If user is in budget
|
||||||
## 4.1 check personal budget, if personal key
|
## 4.1 check personal budget, if personal key
|
||||||
if (
|
if (
|
||||||
|
@ -140,6 +117,7 @@ def common_checks( # noqa: PLR0915
|
||||||
max_budget=user_budget,
|
max_budget=user_budget,
|
||||||
message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
|
message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
|
||||||
)
|
)
|
||||||
|
|
||||||
## 4.2 check team member budget, if team key
|
## 4.2 check team member budget, if team key
|
||||||
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
||||||
|
@ -150,6 +128,7 @@ def common_checks( # noqa: PLR0915
|
||||||
max_budget=end_user_budget,
|
max_budget=end_user_budget,
|
||||||
message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
|
message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
||||||
if (
|
if (
|
||||||
general_settings.get("enforce_user_param", None) is not None
|
general_settings.get("enforce_user_param", None) is not None
|
||||||
|
@ -981,3 +960,144 @@ async def _virtual_key_max_budget_check(
|
||||||
current_cost=valid_token.spend,
|
current_cost=valid_token.spend,
|
||||||
max_budget=valid_token.max_budget,
|
max_budget=valid_token.max_budget,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _virtual_key_soft_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
):
    """
    Send a "soft_budget" alert when a virtual key has reached its soft budget.

    Nothing happens when the key has no soft budget configured (None or 0) or
    when its spend is still below the soft budget. Otherwise the crossing is
    logged at debug level and `proxy_logging_obj.budget_alerts` is scheduled
    as a background task; the alert is not awaited here.

    Args:
        valid_token: The authenticated key whose spend / soft budget are compared.
        proxy_logging_obj: Object whose `budget_alerts` coroutine delivers the alert.
    """
    soft_budget = valid_token.soft_budget
    if not (soft_budget and valid_token.spend >= soft_budget):
        # No soft budget set, or spend has not reached it yet.
        return

    verbose_proxy_logger.debug(
        "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
        valid_token.token,
        valid_token.spend,
        soft_budget,
    )

    alert_payload = CallInfo(
        token=valid_token.token,
        spend=valid_token.spend,
        max_budget=valid_token.max_budget,
        soft_budget=soft_budget,
        user_id=valid_token.user_id,
        team_id=valid_token.team_id,
        team_alias=valid_token.team_alias,
        user_email=None,
        key_alias=valid_token.key_alias,
    )

    # Fire-and-forget: the alert runs in the background after this returns.
    pending_alert = proxy_logging_obj.budget_alerts(
        type="soft_budget",
        user_info=alert_payload,
    )
    asyncio.create_task(pending_alert)
|
||||||
|
|
||||||
|
|
||||||
|
async def _team_max_budget_check(
    team_object: Optional[LiteLLM_TeamTable],
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
):
    """
    Reject the request when the team has spent more than its max budget.

    The check is a no-op when there is no team, when the team has no
    `max_budget` or no `spend` recorded, or when spend is not strictly
    greater than the max budget.

    When the team is over budget and a `valid_token` is provided, a
    "team_budget" alert is scheduled as a background task before raising.

    Raises:
        litellm.BudgetExceededError: if the team's spend exceeds its max budget.
    """
    if team_object is None or team_object.max_budget is None:
        return
    if team_object.spend is None or not team_object.spend > team_object.max_budget:
        return

    # Over budget from here on.
    if valid_token:
        alert_payload = CallInfo(
            token=valid_token.token,
            spend=team_object.spend,
            max_budget=team_object.max_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
        )
        # Fire-and-forget: the alert is not awaited before the error is raised.
        pending_alert = proxy_logging_obj.budget_alerts(
            type="team_budget",
            user_info=alert_payload,
        )
        asyncio.create_task(pending_alert)

    raise litellm.BudgetExceededError(
        current_cost=team_object.spend,
        max_budget=team_object.max_budget,
        message=f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _team_model_access_check(
    model: Optional[str],
    team_object: Optional[LiteLLM_TeamTable],
    llm_router: Optional[Router],
):
    """
    Verify that the team is permitted to call `model`.

    Access is granted (the function simply returns) when any of these hold:
      - no model or no team was supplied
      - the team's model list is None or empty
      - the model is listed verbatim in the team's models
      - the team's models contain "all-proxy-models", "*" or "openai/*"
      - `model_in_access_group` reports the model as part of a team access group
      - the requested model name itself contains "*"
      - any team model entry matches via `is_model_allowed_by_pattern`

    Raises:
        Exception: if none of the rules above grants access to `model`.
    """
    if model is None or team_object is None:
        return

    allowed_models = team_object.models
    if allowed_models is None or len(allowed_models) == 0:
        return
    if model in allowed_models:
        return

    # Entries that open up every model on the proxy.
    if (
        "all-proxy-models" in allowed_models
        or "*" in allowed_models
        or "openai/*" in allowed_models
    ):
        return

    # One of the team's entries may be an access group containing the model.
    if (
        model_in_access_group(
            model=model, team_models=allowed_models, llm_router=llm_router
        )
        is True
    ):
        return

    # A requested model name that itself carries a wildcard is let through.
    if model and "*" in model:
        return

    # Wildcard entries such as "bedrock/*" on the team's list.
    for team_entry in allowed_models:
        if is_model_allowed_by_pattern(model=model, allowed_model_pattern=team_entry):
            return

    raise Exception(
        f"Team={team_object.team_id} not allowed to call model={model}. Allowed team models = {team_object.models}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
    """
    Check if a model matches an allowed wildcard pattern.

    Only patterns containing "*" are evaluated; "*" matches any run of
    characters (including none) and every other character must match
    literally. A pattern without "*" always returns False - exact matches
    are expected to be handled by the caller.

    Args:
        model (str): The model to check (e.g., "bedrock/anthropic.claude-3-5-sonnet-20240620")
        allowed_model_pattern (str): The allowed pattern (e.g., "bedrock/*", "*", "openai/*")

    Returns:
        bool: True if model matches the wildcard pattern, False otherwise
    """
    if "*" not in allowed_model_pattern:
        return False

    # Escape the literal segments so regex metacharacters that occur in model
    # names (".", "+", "(", ...) are not interpreted as regex syntax; without
    # this "gpt-3.5-*" would also allow "gpt-3x5-turbo", and a pattern like
    # "(*" would raise re.error. Only "*" acts as a wildcard.
    literal_segments = allowed_model_pattern.split("*")
    regex = ".*".join(re.escape(segment) for segment in literal_segments)

    # fullmatch anchors both ends of the model name.
    return re.fullmatch(regex, model) is not None
|
||||||
|
|
|
@ -25,6 +25,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
_handle_failed_db_connection_for_get_key_object,
|
_handle_failed_db_connection_for_get_key_object,
|
||||||
_virtual_key_max_budget_check,
|
_virtual_key_max_budget_check,
|
||||||
|
_virtual_key_soft_budget_check,
|
||||||
allowed_routes_check,
|
allowed_routes_check,
|
||||||
can_key_call_model,
|
can_key_call_model,
|
||||||
common_checks,
|
common_checks,
|
||||||
|
@ -613,7 +614,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
)
|
)
|
||||||
# run through common checks
|
# run through common checks
|
||||||
_ = common_checks(
|
_ = await common_checks(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=team_object,
|
team_object=team_object,
|
||||||
user_object=user_object,
|
user_object=user_object,
|
||||||
|
@ -622,6 +623,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
route=route,
|
route=route,
|
||||||
llm_router=llm_router,
|
llm_router=llm_router,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
valid_token=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# return UserAPIKeyAuth object
|
# return UserAPIKeyAuth object
|
||||||
|
@ -1099,30 +1102,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
user_obj=user_obj,
|
user_obj=user_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
|
# Check 5. Soft Budget Check
|
||||||
verbose_proxy_logger.debug(
|
await _virtual_key_soft_budget_check(
|
||||||
"Crossed Soft Budget for token %s, spend %s, soft_budget %s",
|
valid_token=valid_token,
|
||||||
valid_token.token,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
valid_token.spend,
|
)
|
||||||
valid_token.soft_budget,
|
|
||||||
)
|
|
||||||
call_info = CallInfo(
|
|
||||||
token=valid_token.token,
|
|
||||||
spend=valid_token.spend,
|
|
||||||
max_budget=valid_token.max_budget,
|
|
||||||
soft_budget=valid_token.soft_budget,
|
|
||||||
user_id=valid_token.user_id,
|
|
||||||
team_id=valid_token.team_id,
|
|
||||||
team_alias=valid_token.team_alias,
|
|
||||||
user_email=None,
|
|
||||||
key_alias=valid_token.key_alias,
|
|
||||||
)
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.budget_alerts(
|
|
||||||
type="soft_budget",
|
|
||||||
user_info=call_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check 5. Token Model Spend is under Model budget
|
# Check 5. Token Model Spend is under Model budget
|
||||||
max_budget_per_model = valid_token.model_max_budget
|
max_budget_per_model = valid_token.model_max_budget
|
||||||
|
@ -1141,35 +1125,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
user_api_key_dict=valid_token,
|
user_api_key_dict=valid_token,
|
||||||
model=current_model,
|
model=current_model,
|
||||||
)
|
)
|
||||||
# Check 6. Team spend is under Team budget
|
|
||||||
if (
|
|
||||||
hasattr(valid_token, "team_spend")
|
|
||||||
and valid_token.team_spend is not None
|
|
||||||
and hasattr(valid_token, "team_max_budget")
|
|
||||||
and valid_token.team_max_budget is not None
|
|
||||||
):
|
|
||||||
call_info = CallInfo(
|
|
||||||
token=valid_token.token,
|
|
||||||
spend=valid_token.team_spend,
|
|
||||||
max_budget=valid_token.team_max_budget,
|
|
||||||
user_id=valid_token.user_id,
|
|
||||||
team_id=valid_token.team_id,
|
|
||||||
team_alias=valid_token.team_alias,
|
|
||||||
)
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.budget_alerts(
|
|
||||||
type="team_budget",
|
|
||||||
user_info=call_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if valid_token.team_spend >= valid_token.team_max_budget:
|
# Check 6: Additional Common Checks across jwt + key auth
|
||||||
raise litellm.BudgetExceededError(
|
|
||||||
current_cost=valid_token.team_spend,
|
|
||||||
max_budget=valid_token.team_max_budget,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check 8: Additional Common Checks across jwt + key auth
|
|
||||||
if valid_token.team_id is not None:
|
if valid_token.team_id is not None:
|
||||||
_team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
|
_team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
|
||||||
team_id=valid_token.team_id,
|
team_id=valid_token.team_id,
|
||||||
|
@ -1184,7 +1141,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
else:
|
else:
|
||||||
_team_obj = None
|
_team_obj = None
|
||||||
|
|
||||||
# Check 9: Check if key is a service account key
|
# Check 7: Check if key is a service account key
|
||||||
await service_account_checks(
|
await service_account_checks(
|
||||||
valid_token=valid_token,
|
valid_token=valid_token,
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
|
@ -1228,7 +1185,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
user_info=call_info,
|
user_info=call_info,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
_ = common_checks(
|
_ = await common_checks(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=_team_obj,
|
team_object=_team_obj,
|
||||||
user_object=user_obj,
|
user_object=user_obj,
|
||||||
|
@ -1237,6 +1194,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
route=route,
|
route=route,
|
||||||
llm_router=llm_router,
|
llm_router=llm_router,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
valid_token=valid_token,
|
||||||
)
|
)
|
||||||
# Token passed all checks
|
# Token passed all checks
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
|
|
|
@ -19,8 +19,19 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.proxy.auth.auth_checks import get_end_user_object
|
from litellm.proxy.auth.auth_checks import get_end_user_object
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
from litellm.proxy._types import LiteLLM_EndUserTable, LiteLLM_BudgetTable
|
from litellm.proxy._types import (
|
||||||
|
LiteLLM_EndUserTable,
|
||||||
|
LiteLLM_BudgetTable,
|
||||||
|
LiteLLM_UserTable,
|
||||||
|
LiteLLM_TeamTable,
|
||||||
|
)
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
from litellm.proxy.auth.auth_checks import (
|
||||||
|
_team_model_access_check,
|
||||||
|
_virtual_key_soft_budget_check,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy.utils import CallInfo
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
|
@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
|
||||||
|
@ -229,3 +240,192 @@ async def test_is_valid_fallback_model():
|
||||||
pytest.fail("Expected is_valid_fallback_model to fail")
|
pytest.fail("Expected is_valid_fallback_model to fail")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
assert "Invalid" in str(e)
|
assert "Invalid" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "token_spend, max_budget, expect_budget_error",
    [
        (5.0, 10.0, False),  # Under budget
        (10.0, 10.0, True),  # At budget limit
        (15.0, 10.0, True),  # Over budget
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_max_budget_check(
    token_spend, max_budget, expect_budget_error
):
    """
    Exercise `_virtual_key_max_budget_check` for spend below, at and above the max budget.

    Expectations:
    - BudgetExceededError is raised only for the cases flagged `expect_budget_error`,
      and carries the token's spend and max budget.
    - The budget alert hook is invoked in every case.
    """
    from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check
    from litellm.proxy.utils import ProxyLogging

    # Key under test and the user it belongs to
    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=token_spend,
        max_budget=max_budget,
        user_id="test-user",
        key_alias="test-key",
    )
    user_obj = LiteLLM_UserTable(
        user_id="test-user",
        user_email="test@email.com",
        max_budget=None,
    )
    proxy_logging_obj = ProxyLogging(
        user_api_key_cache=None,
    )

    # Replace the real alert hook with a recorder
    alert_calls = []

    async def mock_budget_alert(*args, **kwargs):
        alert_calls.append((args, kwargs))

    proxy_logging_obj.budget_alerts = mock_budget_alert

    # Run the check and capture the outcome before asserting on it
    budget_error = None
    try:
        await _virtual_key_max_budget_check(
            valid_token=valid_token,
            proxy_logging_obj=proxy_logging_obj,
            user_obj=user_obj,
        )
    except litellm.BudgetExceededError as exc:
        budget_error = exc

    if expect_budget_error:
        if budget_error is None:
            pytest.fail(
                f"Expected BudgetExceededError for spend={token_spend}, max_budget={max_budget}"
            )
        assert budget_error.current_cost == token_spend
        assert budget_error.max_budget == max_budget
    elif budget_error is not None:
        pytest.fail(
            f"Unexpected BudgetExceededError for spend={token_spend}, max_budget={max_budget}"
        )

    # Give the background alert task time to run
    await asyncio.sleep(1)

    # Verify budget alert was triggered
    assert len(alert_calls) > 0, "Budget alert should be triggered"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model, team_models, expect_to_work",
    [
        ("gpt-4", ["gpt-4"], True),  # exact match
        ("gpt-4", ["all-proxy-models"], True),  # all-proxy-models access
        ("gpt-4", ["*"], True),  # wildcard access
        ("gpt-4", ["openai/*"], True),  # openai wildcard access
        (
            "bedrock/anthropic.claude-3-5-sonnet-20240620",
            ["bedrock/*"],
            True,
        ),  # wildcard access
        (
            "bedrockz/anthropic.claude-3-5-sonnet-20240620",
            ["bedrock/*"],
            False,
        ),  # non-match wildcard access
        ("bedrock/very_new_model", ["bedrock/*"], True),  # bedrock wildcard access
        (
            "bedrock/claude-3-5-sonnet-20240620",
            ["bedrock/claude-*"],
            True,
        ),  # match on pattern
        (
            "bedrock/claude-3-6-sonnet-20240620",
            ["bedrock/claude-3-5-*"],
            False,
        ),  # don't match on pattern
        ("openai/gpt-4o", ["openai/*"], True),  # openai wildcard access
        ("gpt-4", ["gpt-3.5-turbo"], False),  # model not in allowed list
        ("claude-3", [], True),  # empty model list (allows all)
    ],
)
@pytest.mark.asyncio
async def test_team_model_access_check(model, team_models, expect_to_work):
    """
    Exercise `_team_model_access_check` against a team with the given model list.

    Parametrized cases cover:
    - exact model match
    - "all-proxy-models", "*" and "openai/*" entries
    - provider wildcards ("bedrock/*") that match and that do not match
    - partial-name wildcards ("bedrock/claude-*") that match and that do not match
    - a model that is not on the allowed list
    - an empty team model list
    """
    team_object = LiteLLM_TeamTable(
        team_id="test-team",
        models=team_models,
    )

    # Run the check first, then judge the outcome
    access_error = None
    try:
        _team_model_access_check(
            model=model,
            team_object=team_object,
            llm_router=None,
        )
    except Exception as exc:
        access_error = exc

    if access_error is None:
        if not expect_to_work:
            pytest.fail(
                f"Expected model access check to fail for model={model}, team_models={team_models}"
            )
    elif expect_to_work:
        pytest.fail(
            f"Expected model access check to work for model={model}, team_models={team_models}. Got error: {str(access_error)}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "spend, soft_budget, expect_alert",
    [
        (100, 50, True),  # Over soft budget
        (50, 50, True),  # At soft budget
        (25, 50, False),  # Under soft budget
        (100, None, False),  # No soft budget set
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
    """
    Test cases for _virtual_key_soft_budget_check:
    1. Spend over soft budget
    2. Spend at soft budget
    3. Spend under soft budget
    4. No soft budget set

    The mock only records the arguments it receives; all assertions run in the
    test body. Assertions placed inside the mocked coroutine would execute in
    a background task whose exception is never retrieved, so a failure there
    would not fail the test.
    """
    recorded_alerts = []

    class MockProxyLogging:
        # `type` must keep this name: the code under test passes it by keyword.
        async def budget_alerts(self, type, user_info):
            recorded_alerts.append((type, user_info))

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=spend,
        soft_budget=soft_budget,
        user_id="test-user",
        team_id="test-team",
        key_alias="test-key",
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_soft_budget_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
    )

    await asyncio.sleep(0.1)  # Allow time for the alert task to complete

    alert_triggered = len(recorded_alerts) > 0
    assert (
        alert_triggered == expect_alert
    ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"

    # Validate what the alert hook was actually called with
    for alert_type, alert_info in recorded_alerts:
        assert alert_type == "soft_budget"
        assert isinstance(alert_info, CallInfo)
|
Loading…
Add table
Add a link
Reference in a new issue