From edf403b537dd157073c678e68f9de98487412a3f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 9 Apr 2025 16:18:16 -0700 Subject: [PATCH] ui sso fix team assignments --- litellm/proxy/management_endpoints/ui_sso.py | 105 +++++++++++++------ 1 file changed, 72 insertions(+), 33 deletions(-) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index fb5550e5d7..af32ca7fb2 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -55,6 +55,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import ( ) from litellm.proxy.management_endpoints.team_endpoints import team_member_add from litellm.proxy.management_endpoints.types import CustomOpenID +from litellm.proxy.utils import PrismaClient from litellm.secret_managers.main import str_to_bool from litellm.types.proxy.management_endpoints.ui_sso import * @@ -461,40 +462,22 @@ async def auth_callback(request: Request): # noqa: PLR0915 f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}" ) - if user_info is not None: - user_id = user_info.user_id - user_defined_values = SSOUserDefinedValues( - models=getattr(user_info, "models", user_id_models), - user_id=user_info.user_id, - user_email=getattr(user_info, "user_email", user_email), - user_role=getattr(user_info, "user_role", None), - max_budget=getattr( - user_info, "max_budget", max_internal_user_budget - ), - budget_duration=getattr( - user_info, "budget_duration", internal_user_budget_duration - ), - ) - - user_role = getattr(user_info, "user_role", None) - - # update id - await prisma_client.db.litellm_usertable.update_many( - where={"user_email": user_email}, data={"user_id": user_id} # type: ignore - ) + # Upsert SSO User to LiteLLM DB + user_info = await SSOAuthenticationHandler.upsert_sso_user( + result=result, + user_info=user_info, + user_email=user_email, + user_id_models=user_id_models, + max_internal_user_budget=max_internal_user_budget, + internal_user_budget_duration=internal_user_budget_duration, + user_defined_values=user_defined_values, + prisma_client=prisma_client, + ) + if user_info and user_info.user_role is not None: + user_role = user_info.user_role else: - verbose_proxy_logger.info( - "user not in DB, inserting user into LiteLLM DB" - ) - # user not in DB, insert User into LiteLLM DB - user_info = await insert_sso_user( - result_openid=result, - user_defined_values=user_defined_values, - ) + user_role = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY - user_role = ( - user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY - ) await SSOAuthenticationHandler.add_user_to_teams_from_sso_response( result=result, user_info=user_info, @@ -842,10 +825,61 @@ class SSOAuthenticationHandler: redirect_url += "/" + sso_callback_route return redirect_url + @staticmethod + async def upsert_sso_user( + result: Optional[Union[CustomOpenID, OpenID, dict]], + user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]], + user_email: Optional[str], + user_id_models: List[str], + max_internal_user_budget: Optional[float], + internal_user_budget_duration: Optional[str], + user_defined_values: Optional[SSOUserDefinedValues], + prisma_client: PrismaClient, + ): + """ + Connects the SSO Users to the User Table in LiteLLM DB + + - If user on LiteLLM DB, update the user_id with the SSO user_id + - If user not on LiteLLM DB, insert the user into LiteLLM DB + """ + try: + if user_info is not None: + user_id = user_info.user_id + user_defined_values = SSOUserDefinedValues( + models=getattr(user_info, "models", user_id_models), + user_id=user_info.user_id or "", + user_email=getattr(user_info, "user_email", user_email), + user_role=getattr(user_info, "user_role", None), + max_budget=getattr( + user_info, "max_budget", max_internal_user_budget + ), + budget_duration=getattr( + user_info, "budget_duration", internal_user_budget_duration + ), + ) + + # update id + await prisma_client.db.litellm_usertable.update_many( + where={"user_email": user_email}, data={"user_id": user_id} # type: ignore + ) + else: + verbose_proxy_logger.info( + "user not in DB, inserting user into LiteLLM DB" + ) + # user not in DB, insert User into LiteLLM DB + user_info = await insert_sso_user( + result_openid=result, + user_defined_values=user_defined_values, + ) + return user_info + except Exception as e: + verbose_proxy_logger.error(f"Error upserting SSO user into LiteLLM DB: {e}") + return user_info + @staticmethod async def add_user_to_teams_from_sso_response( result: Optional[Union[CustomOpenID, OpenID, dict]], - user_info: Union[NewUserResponse, LiteLLM_UserTable], + user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]], ): """ Adds the user as a team member to the teams specified in the SSO responses `team_ids` field @@ -853,6 +887,11 @@ class SSOAuthenticationHandler: The `team_ids` field is populated by litellm after processing the SSO response """ + if user_info is None: + verbose_proxy_logger.debug( + "User not found in LiteLLM DB, skipping team member addition" + ) + return sso_teams = getattr(result, "team_ids", []) await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)