fix(proxy_server.py): respect internal_user_budget_duration for sso user

This commit is contained in:
Krrish Dholakia 2024-08-08 17:28:28 -07:00
parent 75ac9323e8
commit 488a78e5f4
2 changed files with 24 additions and 0 deletions

View file

@ -8634,6 +8634,7 @@ async def auth_callback(request: Request):
user_info = None
user_id_models: List = []
max_internal_user_budget = litellm.max_internal_user_budget
internal_user_budget_duration = litellm.internal_user_budget_duration
# User might not be already created on first generation of key
# But if it is, we want their models preferences
@ -8651,6 +8652,7 @@ async def auth_callback(request: Request):
"user_email": user_email,
"max_budget": max_internal_user_budget,
"user_role": None,
"budget_duration": internal_user_budget_duration,
}
_user_id_from_sso = user_id
try:
@ -8669,6 +8671,9 @@ async def auth_callback(request: Request):
"max_budget": getattr(
user_info, "max_budget", max_internal_user_budget
),
"budget_duration": getattr(
user_info, "budget_duration", internal_user_budget_duration
),
}
user_role = getattr(user_info, "user_role", None)
@ -8685,6 +8690,9 @@ async def auth_callback(request: Request):
"max_budget": getattr(
user_info, "max_budget", max_internal_user_budget
),
"budget_duration": getattr(
user_info, "budget_duration", internal_user_budget_duration
),
}
user_role = getattr(user_info, "user_role", None)
@ -8705,11 +8713,26 @@ async def auth_callback(request: Request):
"max_budget": litellm.default_user_params.get(
"max_budget", max_internal_user_budget
),
"budget_duration": litellm.default_user_params.get(
"budget_duration", internal_user_budget_duration
),
}
except Exception as e:
pass
if (
user_defined_values["max_budget"] is None
and litellm.max_internal_user_budget is not None
):
user_defined_values["max_budget"] = litellm.max_internal_user_budget
if (
user_defined_values["budget_duration"] is None
and litellm.internal_user_budget_duration is not None
):
user_defined_values["budget_duration"] = litellm.internal_user_budget_duration
verbose_proxy_logger.info(
f"user_defined_values for creating ui key: {user_defined_values}"
)