(Testing + Refactor) - Unit testing for team and virtual key budget checks (#7945)

* unit testing for test_virtual_key_max_budget_check

* refactor _team_max_budget_check

* is_model_allowed_by_pattern
This commit is contained in:
Ishaan Jaff 2025-01-23 16:58:16 -08:00 committed by GitHub
parent 8479d05d49
commit e0fb7eb4f7
3 changed files with 377 additions and 98 deletions

View file

@ -25,6 +25,7 @@ from litellm.proxy.auth.auth_checks import (
_cache_key_object,
_handle_failed_db_connection_for_get_key_object,
_virtual_key_max_budget_check,
_virtual_key_soft_budget_check,
allowed_routes_check,
can_key_call_model,
common_checks,
@ -613,7 +614,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
parent_otel_span=parent_otel_span,
)
# run through common checks
_ = common_checks(
_ = await common_checks(
request_body=request_data,
team_object=team_object,
user_object=user_object,
@ -622,6 +623,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
global_proxy_spend=global_proxy_spend,
route=route,
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
valid_token=None,
)
# return UserAPIKeyAuth object
@ -1099,30 +1102,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
user_obj=user_obj,
)
if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
verbose_proxy_logger.debug(
"Crossed Soft Budget for token %s, spend %s, soft_budget %s",
valid_token.token,
valid_token.spend,
valid_token.soft_budget,
)
call_info = CallInfo(
token=valid_token.token,
spend=valid_token.spend,
max_budget=valid_token.max_budget,
soft_budget=valid_token.soft_budget,
user_id=valid_token.user_id,
team_id=valid_token.team_id,
team_alias=valid_token.team_alias,
user_email=None,
key_alias=valid_token.key_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="soft_budget",
user_info=call_info,
)
)
# Check 5. Soft Budget Check
await _virtual_key_soft_budget_check(
valid_token=valid_token,
proxy_logging_obj=proxy_logging_obj,
)
# Check 5. Token Model Spend is under Model budget
max_budget_per_model = valid_token.model_max_budget
@ -1141,35 +1125,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
user_api_key_dict=valid_token,
model=current_model,
)
# Check 6. Team spend is under Team budget
if (
hasattr(valid_token, "team_spend")
and valid_token.team_spend is not None
and hasattr(valid_token, "team_max_budget")
and valid_token.team_max_budget is not None
):
call_info = CallInfo(
token=valid_token.token,
spend=valid_token.team_spend,
max_budget=valid_token.team_max_budget,
user_id=valid_token.user_id,
team_id=valid_token.team_id,
team_alias=valid_token.team_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="team_budget",
user_info=call_info,
)
)
if valid_token.team_spend >= valid_token.team_max_budget:
raise litellm.BudgetExceededError(
current_cost=valid_token.team_spend,
max_budget=valid_token.team_max_budget,
)
# Check 8: Additional Common Checks across jwt + key auth
# Check 6: Additional Common Checks across jwt + key auth
if valid_token.team_id is not None:
_team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
team_id=valid_token.team_id,
@ -1184,7 +1141,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
else:
_team_obj = None
# Check 9: Check if key is a service account key
# Check 7: Check if key is a service account key
await service_account_checks(
valid_token=valid_token,
request_data=request_data,
@ -1228,7 +1185,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
user_info=call_info,
)
)
_ = common_checks(
_ = await common_checks(
request_body=request_data,
team_object=_team_obj,
user_object=user_obj,
@ -1237,6 +1194,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
global_proxy_spend=global_proxy_spend,
route=route,
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
valid_token=valid_token,
)
# Token passed all checks
if valid_token is None: