fix(user_api_key_auth.py): respect team budgets over user budget, if key belongs to team

Closes https://github.com/BerriAI/litellm/issues/5097
This commit is contained in:
Krrish Dholakia 2024-08-07 14:32:27 -07:00
parent f579aef740
commit d832327ccf
3 changed files with 92 additions and 110 deletions

View file

@ -552,6 +552,7 @@ async def user_api_key_auth(
key=api_key
)
if valid_token is None:
user_obj: Optional[LiteLLM_UserTable] = None
## check db
verbose_proxy_logger.debug("api key: %s", api_key)
if prisma_client is not None:
@ -650,114 +651,26 @@ async def user_api_key_auth(
valid_token=valid_token,
)
# Check 2. If user_id for this token is in budget
# Check 2. If user_id for this token is in budget - done in common_checks()
if valid_token.user_id is not None:
user_id_list = [valid_token.user_id]
for id in user_id_list:
value = user_api_key_cache.get_cache(key=id)
if value is not None:
if user_id_information is None:
user_id_information = []
user_id_information.append(value)
if user_id_information is None or (
isinstance(user_id_information, list)
and len(user_id_information) < 1
):
if prisma_client is not None:
user_id_information = await prisma_client.get_data(
user_id_list=[
valid_token.user_id,
],
table_name="user",
query_type="find_all",
)
if user_id_information is not None:
for _id in user_id_information:
await user_api_key_cache.async_set_cache(
key=_id["user_id"],
value=_id,
)
user_obj = await get_user_object(
user_id=valid_token.user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id_upsert=False,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# value = user_api_key_cache.get_cache(key=id)
if user_obj is not None:
if user_id_information is None:
user_id_information = []
user_id_information.append(user_obj.model_dump())
verbose_proxy_logger.debug(
f"user_id_information: {user_id_information}"
)
if user_id_information is not None:
if isinstance(user_id_information, list):
## Check if user in budget
for _user in user_id_information:
if _user is None:
continue
assert isinstance(_user, dict)
# check if user is admin #
# Token exists, not expired now check if its in budget for the user
user_max_budget = _user.get("max_budget", None)
user_current_spend = _user.get("spend", None)
verbose_proxy_logger.debug(
f"user_id: {_user.get('user_id', None)}; user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
)
if (
user_max_budget is not None
and user_current_spend is not None
):
call_info = CallInfo(
token=valid_token.token,
spend=user_current_spend,
max_budget=user_max_budget,
user_id=_user.get("user_id", None),
user_email=_user.get("user_email", None),
key_alias=valid_token.key_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="user_budget",
user_info=call_info,
)
)
_user_id = _user.get("user_id", None)
if user_current_spend > user_max_budget:
raise litellm.BudgetExceededError(
current_cost=user_current_spend,
max_budget=user_max_budget,
)
else:
# Token exists, not expired now check if its in budget for the user
user_max_budget = getattr(
user_id_information, "max_budget", None
)
user_current_spend = getattr(user_id_information, "spend", None)
if (
user_max_budget is not None
and user_current_spend is not None
):
call_info = CallInfo(
token=valid_token.token,
spend=user_current_spend,
max_budget=user_max_budget,
user_id=getattr(user_id_information, "user_id", None),
user_email=getattr(
user_id_information, "user_email", None
),
key_alias=valid_token.key_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="user_budget",
user_info=call_info,
)
)
if user_current_spend > user_max_budget:
raise litellm.BudgetExceededError(
current_cost=user_current_spend,
max_budget=user_max_budget,
)
# Check 3. Check if user is in their team budget
if valid_token.team_member_spend is not None:
if prisma_client is not None:
@ -983,7 +896,7 @@ async def user_api_key_auth(
_ = common_checks(
request_body=request_data,
team_object=_team_obj,
user_object=None,
user_object=user_obj,
end_user_object=_end_user_object,
general_settings=general_settings,
global_proxy_spend=global_proxy_spend,