diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index af7dcfe608..4282a44941 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -485,7 +485,14 @@ async def auth_callback(request: Request): # noqa: PLR0915 redirect_uri=redirect_url, allow_insecure_http=True, ) - result = await microsoft_sso.verify_and_process(request) + original_msft_result = await microsoft_sso.verify_and_process( + request=request, + convert_response=False, + ) + result = MicrosoftSSOHandler.openid_from_response( + response=original_msft_result, + jwt_handler=jwt_handler, + ) elif generic_client_id is not None: result = await get_generic_sso_response( request=request, @@ -495,6 +502,8 @@ async def auth_callback(request: Request): # noqa: PLR0915 ) # User is Authe'd in - generate key for the UI to access Proxy verbose_proxy_logger.debug(f"SSO callback result: {result}") + result = cast(CustomOpenID, result) + result.team_ids = jwt_handler.get_team_ids_from_jwt(cast(dict, result)) user_email: Optional[str] = getattr(result, "email", None) user_id: Optional[str] = getattr(result, "id", None) if result is not None else None @@ -780,3 +789,27 @@ async def get_ui_settings(request: Request): ), "DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries, } + + +class MicrosoftSSOHandler: + """ + Handles Microsoft SSO callback response and returns a CustomOpenID object + """ + + @staticmethod + def openid_from_response( + response: Optional[dict], jwt_handler: JWTHandler + ) -> CustomOpenID: + response = response or {} + verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") + openid_response = CustomOpenID( + email=response.get("mail"), + display_name=response.get("displayName"), + provider="microsoft", + id=response.get("id"), + first_name=response.get("givenName"), + last_name=response.get("surname"), + team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)), + ) + verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}") + return openid_response