mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(user_api_key_auth.py): respect team budgets over user budget, if key belongs to team
Closes https://github.com/BerriAI/litellm/issues/5097
This commit is contained in:
parent
f579aef740
commit
d832327ccf
3 changed files with 92 additions and 110 deletions
|
@ -552,6 +552,7 @@ async def user_api_key_auth(
|
|||
key=api_key
|
||||
)
|
||||
if valid_token is None:
|
||||
user_obj: Optional[LiteLLM_UserTable] = None
|
||||
## check db
|
||||
verbose_proxy_logger.debug("api key: %s", api_key)
|
||||
if prisma_client is not None:
|
||||
|
@ -650,114 +651,26 @@ async def user_api_key_auth(
|
|||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
# Check 2. If user_id for this token is in budget
|
||||
# Check 2. If user_id for this token is in budget - done in common_checks()
|
||||
if valid_token.user_id is not None:
|
||||
user_id_list = [valid_token.user_id]
|
||||
for id in user_id_list:
|
||||
value = user_api_key_cache.get_cache(key=id)
|
||||
if value is not None:
|
||||
if user_id_information is None:
|
||||
user_id_information = []
|
||||
user_id_information.append(value)
|
||||
if user_id_information is None or (
|
||||
isinstance(user_id_information, list)
|
||||
and len(user_id_information) < 1
|
||||
):
|
||||
if prisma_client is not None:
|
||||
user_id_information = await prisma_client.get_data(
|
||||
user_id_list=[
|
||||
valid_token.user_id,
|
||||
],
|
||||
table_name="user",
|
||||
query_type="find_all",
|
||||
)
|
||||
if user_id_information is not None:
|
||||
for _id in user_id_information:
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=_id["user_id"],
|
||||
value=_id,
|
||||
)
|
||||
user_obj = await get_user_object(
|
||||
user_id=valid_token.user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=False,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# value = user_api_key_cache.get_cache(key=id)
|
||||
if user_obj is not None:
|
||||
if user_id_information is None:
|
||||
user_id_information = []
|
||||
user_id_information.append(user_obj.model_dump())
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_id_information: {user_id_information}"
|
||||
)
|
||||
|
||||
if user_id_information is not None:
|
||||
if isinstance(user_id_information, list):
|
||||
## Check if user in budget
|
||||
for _user in user_id_information:
|
||||
if _user is None:
|
||||
continue
|
||||
assert isinstance(_user, dict)
|
||||
# check if user is admin #
|
||||
|
||||
# Token exists, not expired now check if its in budget for the user
|
||||
user_max_budget = _user.get("max_budget", None)
|
||||
user_current_spend = _user.get("spend", None)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_id: {_user.get('user_id', None)}; user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
|
||||
)
|
||||
|
||||
if (
|
||||
user_max_budget is not None
|
||||
and user_current_spend is not None
|
||||
):
|
||||
call_info = CallInfo(
|
||||
token=valid_token.token,
|
||||
spend=user_current_spend,
|
||||
max_budget=user_max_budget,
|
||||
user_id=_user.get("user_id", None),
|
||||
user_email=_user.get("user_email", None),
|
||||
key_alias=valid_token.key_alias,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="user_budget",
|
||||
user_info=call_info,
|
||||
)
|
||||
)
|
||||
|
||||
_user_id = _user.get("user_id", None)
|
||||
if user_current_spend > user_max_budget:
|
||||
raise litellm.BudgetExceededError(
|
||||
current_cost=user_current_spend,
|
||||
max_budget=user_max_budget,
|
||||
)
|
||||
else:
|
||||
# Token exists, not expired now check if its in budget for the user
|
||||
user_max_budget = getattr(
|
||||
user_id_information, "max_budget", None
|
||||
)
|
||||
user_current_spend = getattr(user_id_information, "spend", None)
|
||||
|
||||
if (
|
||||
user_max_budget is not None
|
||||
and user_current_spend is not None
|
||||
):
|
||||
call_info = CallInfo(
|
||||
token=valid_token.token,
|
||||
spend=user_current_spend,
|
||||
max_budget=user_max_budget,
|
||||
user_id=getattr(user_id_information, "user_id", None),
|
||||
user_email=getattr(
|
||||
user_id_information, "user_email", None
|
||||
),
|
||||
key_alias=valid_token.key_alias,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="user_budget",
|
||||
user_info=call_info,
|
||||
)
|
||||
)
|
||||
|
||||
if user_current_spend > user_max_budget:
|
||||
raise litellm.BudgetExceededError(
|
||||
current_cost=user_current_spend,
|
||||
max_budget=user_max_budget,
|
||||
)
|
||||
|
||||
# Check 3. Check if user is in their team budget
|
||||
if valid_token.team_member_spend is not None:
|
||||
if prisma_client is not None:
|
||||
|
@ -983,7 +896,7 @@ async def user_api_key_auth(
|
|||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
team_object=_team_obj,
|
||||
user_object=None,
|
||||
user_object=user_obj,
|
||||
end_user_object=_end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue