diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 17d20c8c8f..e1a9c2e6e5 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -175,7 +175,7 @@ async def get_generic_sso_response( jwt_handler: JWTHandler, generic_client_id: str, redirect_url: str, -) -> Optional[OpenID]: +) -> Union[OpenID, dict]: # make generic sso provider from fastapi_sso.sso.base import DiscoveryDocument from fastapi_sso.sso.generic import create_provider @@ -252,7 +252,7 @@ async def get_generic_sso_response( request, params={"include_client_id": generic_include_client_id} ) verbose_proxy_logger.debug("generic result: %s", result) - return result + return result or {} async def create_team_member_add_task(team_id, user_info): @@ -571,7 +571,7 @@ async def auth_callback(request: Request): async def insert_sso_user( - result_openid: Optional[OpenID], + result_openid: Optional[Union[OpenID, dict]], user_defined_values: Optional[SSOUserDefinedValues] = None, ) -> NewUserResponse: """ @@ -587,6 +587,10 @@ async def insert_sso_user( verbose_proxy_logger.debug( f"Inserting SSO user into DB. User values: {user_defined_values}" ) + if result_openid is None: + raise ValueError("result_openid is None") + if isinstance(result_openid, dict): + result_openid = OpenID(**result_openid) if user_defined_values is None: raise ValueError("user_defined_values is None") @@ -840,7 +844,14 @@ class MicrosoftSSOHandler: microsoft_client_id: str, redirect_url: str, jwt_handler: JWTHandler, - ) -> CustomOpenID: + return_raw_sso_response: bool = False, + ) -> Union[CustomOpenID, OpenID, dict]: + """ + Get the Microsoft SSO callback response + + Args: + return_raw_sso_response: If True, return the raw SSO response + """ from fastapi_sso.sso.microsoft import MicrosoftSSO microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) @@ -870,6 +881,11 @@ class MicrosoftSSOHandler: request=request, convert_response=False, ) + + # if user is trying to get the raw sso response for debugging, return the raw sso response + if return_raw_sso_response: + return original_msft_result or {} + result = MicrosoftSSOHandler.openid_from_response( response=original_msft_result, jwt_handler=jwt_handler, @@ -905,7 +921,14 @@ class GoogleSSOHandler: request: Request, google_client_id: str, redirect_url: str, - ) -> Optional[OpenID]: + return_raw_sso_response: bool = False, + ) -> Union[OpenID, dict]: + """ + Get the Google SSO callback response + + Args: + return_raw_sso_response: If True, return the raw SSO response + """ from fastapi_sso.sso.google import GoogleSSO google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) @@ -921,8 +944,19 @@ class GoogleSSOHandler: redirect_uri=redirect_url, client_secret=google_client_secret, ) + + # if user is trying to get the raw sso response for debugging, return the raw sso response + if return_raw_sso_response: + return ( + await google_sso.verify_and_process( + request=request, + convert_response=False, + ) + or {} + ) + result = await google_sso.verify_and_process(request) - return result + return result or {} @router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False) @@ -1002,6 +1036,7 @@ async def debug_sso_callback(request: Request): request=request, google_client_id=google_client_id, redirect_url=redirect_url, + return_raw_sso_response=True, ) elif microsoft_client_id is not None: result = await MicrosoftSSOHandler.get_microsoft_callback_response( @@ -1009,6 +1044,7 @@ async def debug_sso_callback(request: Request): microsoft_client_id=microsoft_client_id, redirect_url=redirect_url, jwt_handler=jwt_handler, + return_raw_sso_response=True, ) elif generic_client_id is not None: result = await get_generic_sso_response(