mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(Testing + Refactor) - Unit testing for team and virtual key budget checks (#7945)
* unit testing for test_virtual_key_max_budget_check * refactor _team_max_budget_check * is_model_allowed_by_pattern
This commit is contained in:
parent
8479d05d49
commit
e0fb7eb4f7
3 changed files with 377 additions and 98 deletions
|
@ -25,6 +25,7 @@ from litellm.proxy.auth.auth_checks import (
|
|||
_cache_key_object,
|
||||
_handle_failed_db_connection_for_get_key_object,
|
||||
_virtual_key_max_budget_check,
|
||||
_virtual_key_soft_budget_check,
|
||||
allowed_routes_check,
|
||||
can_key_call_model,
|
||||
common_checks,
|
||||
|
@ -613,7 +614,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
_ = await common_checks(
|
||||
request_body=request_data,
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
|
@ -622,6 +623,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
valid_token=None,
|
||||
)
|
||||
|
||||
# return UserAPIKeyAuth object
|
||||
|
@ -1099,30 +1102,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
user_obj=user_obj,
|
||||
)
|
||||
|
||||
if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
|
||||
verbose_proxy_logger.debug(
|
||||
"Crossed Soft Budget for token %s, spend %s, soft_budget %s",
|
||||
valid_token.token,
|
||||
valid_token.spend,
|
||||
valid_token.soft_budget,
|
||||
)
|
||||
call_info = CallInfo(
|
||||
token=valid_token.token,
|
||||
spend=valid_token.spend,
|
||||
max_budget=valid_token.max_budget,
|
||||
soft_budget=valid_token.soft_budget,
|
||||
user_id=valid_token.user_id,
|
||||
team_id=valid_token.team_id,
|
||||
team_alias=valid_token.team_alias,
|
||||
user_email=None,
|
||||
key_alias=valid_token.key_alias,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="soft_budget",
|
||||
user_info=call_info,
|
||||
)
|
||||
)
|
||||
# Check 5. Soft Budget Check
|
||||
await _virtual_key_soft_budget_check(
|
||||
valid_token=valid_token,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Check 5. Token Model Spend is under Model budget
|
||||
max_budget_per_model = valid_token.model_max_budget
|
||||
|
@ -1141,35 +1125,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
user_api_key_dict=valid_token,
|
||||
model=current_model,
|
||||
)
|
||||
# Check 6. Team spend is under Team budget
|
||||
if (
|
||||
hasattr(valid_token, "team_spend")
|
||||
and valid_token.team_spend is not None
|
||||
and hasattr(valid_token, "team_max_budget")
|
||||
and valid_token.team_max_budget is not None
|
||||
):
|
||||
call_info = CallInfo(
|
||||
token=valid_token.token,
|
||||
spend=valid_token.team_spend,
|
||||
max_budget=valid_token.team_max_budget,
|
||||
user_id=valid_token.user_id,
|
||||
team_id=valid_token.team_id,
|
||||
team_alias=valid_token.team_alias,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="team_budget",
|
||||
user_info=call_info,
|
||||
)
|
||||
)
|
||||
|
||||
if valid_token.team_spend >= valid_token.team_max_budget:
|
||||
raise litellm.BudgetExceededError(
|
||||
current_cost=valid_token.team_spend,
|
||||
max_budget=valid_token.team_max_budget,
|
||||
)
|
||||
|
||||
# Check 8: Additional Common Checks across jwt + key auth
|
||||
# Check 6: Additional Common Checks across jwt + key auth
|
||||
if valid_token.team_id is not None:
|
||||
_team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
|
||||
team_id=valid_token.team_id,
|
||||
|
@ -1184,7 +1141,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
else:
|
||||
_team_obj = None
|
||||
|
||||
# Check 9: Check if key is a service account key
|
||||
# Check 7: Check if key is a service account key
|
||||
await service_account_checks(
|
||||
valid_token=valid_token,
|
||||
request_data=request_data,
|
||||
|
@ -1228,7 +1185,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
user_info=call_info,
|
||||
)
|
||||
)
|
||||
_ = common_checks(
|
||||
_ = await common_checks(
|
||||
request_body=request_data,
|
||||
team_object=_team_obj,
|
||||
user_object=user_obj,
|
||||
|
@ -1237,6 +1194,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
# Token passed all checks
|
||||
if valid_token is None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue