fix(proxy_server.py): respect internal_user_budget_duration for sso user

This commit is contained in:
Krrish Dholakia 2024-08-08 17:28:28 -07:00
parent 75ac9323e8
commit 488a78e5f4
2 changed files with 24 additions and 0 deletions

View file

@ -1752,3 +1752,4 @@ class SSOUserDefinedValues(TypedDict):
user_email: Optional[str] user_email: Optional[str]
user_role: Optional[str] user_role: Optional[str]
max_budget: Optional[float] max_budget: Optional[float]
budget_duration: Optional[str]

View file

@ -8634,6 +8634,7 @@ async def auth_callback(request: Request):
user_info = None user_info = None
user_id_models: List = [] user_id_models: List = []
max_internal_user_budget = litellm.max_internal_user_budget max_internal_user_budget = litellm.max_internal_user_budget
internal_user_budget_duration = litellm.internal_user_budget_duration
# User might not be already created on first generation of key # User might not be already created on first generation of key
# But if it is, we want their models preferences # But if it is, we want their models preferences
@ -8651,6 +8652,7 @@ async def auth_callback(request: Request):
"user_email": user_email, "user_email": user_email,
"max_budget": max_internal_user_budget, "max_budget": max_internal_user_budget,
"user_role": None, "user_role": None,
"budget_duration": internal_user_budget_duration,
} }
_user_id_from_sso = user_id _user_id_from_sso = user_id
try: try:
@ -8669,6 +8671,9 @@ async def auth_callback(request: Request):
"max_budget": getattr( "max_budget": getattr(
user_info, "max_budget", max_internal_user_budget user_info, "max_budget", max_internal_user_budget
), ),
"budget_duration": getattr(
user_info, "budget_duration", internal_user_budget_duration
),
} }
user_role = getattr(user_info, "user_role", None) user_role = getattr(user_info, "user_role", None)
@ -8685,6 +8690,9 @@ async def auth_callback(request: Request):
"max_budget": getattr( "max_budget": getattr(
user_info, "max_budget", max_internal_user_budget user_info, "max_budget", max_internal_user_budget
), ),
"budget_duration": getattr(
user_info, "budget_duration", internal_user_budget_duration
),
} }
user_role = getattr(user_info, "user_role", None) user_role = getattr(user_info, "user_role", None)
@ -8705,11 +8713,26 @@ async def auth_callback(request: Request):
"max_budget": litellm.default_user_params.get( "max_budget": litellm.default_user_params.get(
"max_budget", max_internal_user_budget "max_budget", max_internal_user_budget
), ),
"budget_duration": litellm.default_user_params.get(
"budget_duration", internal_user_budget_duration
),
} }
except Exception as e: except Exception as e:
pass pass
if (
user_defined_values["max_budget"] is None
and litellm.max_internal_user_budget is not None
):
user_defined_values["max_budget"] = litellm.max_internal_user_budget
if (
user_defined_values["budget_duration"] is None
and litellm.internal_user_budget_duration is not None
):
user_defined_values["budget_duration"] = litellm.internal_user_budget_duration
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"user_defined_values for creating ui key: {user_defined_values}" f"user_defined_values for creating ui key: {user_defined_values}"
) )