diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index 604ceee3e5..5e7438585e 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -198,6 +198,7 @@ This budget does not apply to keys created under non-default teams. ### Auto-add SSO users to teams + 1. Specify the JWT field that contains the team ids, that the user belongs to. ```yaml @@ -207,7 +208,8 @@ general_settings: team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD ``` -This is assuming your SSO token looks like this: +This is assuming your SSO token looks like this. **If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions [here](#debugging-sso-jwt-fields)** + ``` { ..., @@ -231,6 +233,39 @@ curl -X POST '/team/new' \ Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e) +### Debugging SSO JWT fields + +If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions. This guide walks you through setting up a debug callback to view the JWT data during the SSO process. + + + +
+ +1. Add `/sso/debug/callback` as a redirect URL in your SSO provider + + In your SSO provider's settings, add the following URL as a new redirect (callback) URL: + + ```bash showLineNumbers title="Redirect URL" + http:///sso/debug/callback + ``` + + +2. Navigate to the debug login page on your browser + + Navigate to the following URL on your browser: + + ```bash showLineNumbers title="URL to navigate to" + https:///sso/debug/login + ``` + + This will initiate the standard SSO flow. You will be redirected to your SSO provider's login screen, and after successful authentication, you will be redirected back to LiteLLM's debug callback route. + + +3. View the JWT fields + +Once redirected, you should see a page called "SSO Debug Information". This page displays the JWT fields received from your SSO provider (as shown in the image above) + + ### Restrict Users from creating personal keys This is useful if you only want users to create keys under a specific team. diff --git a/docs/my-website/img/debug_sso.png b/docs/my-website/img/debug_sso.png new file mode 100644 index 0000000000..d7dde36892 Binary files /dev/null and b/docs/my-website/img/debug_sso.png differ diff --git a/litellm/proxy/common_utils/html_forms/jwt_display_template.py b/litellm/proxy/common_utils/html_forms/jwt_display_template.py new file mode 100644 index 0000000000..03dff78dba --- /dev/null +++ b/litellm/proxy/common_utils/html_forms/jwt_display_template.py @@ -0,0 +1,284 @@ +# JWT display template for SSO debug callback +jwt_display_template = """ + + + + + LiteLLM SSO Debug - JWT Information + + + + +
+
+ +
+

SSO Debug Information

+

Results from the SSO authentication process.

+ +
+
+ + + + + Authentication Successful +
+

The SSO authentication completed successfully. Below is the information returned by the provider.

+
+ +
+ +
+ +
+
+ + + + + + JSON Representation +
+
+
Loading...
+
+
+ +
+
+ + + Try Another SSO Login + +
+ + + + +""" diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 970587ded9..c9388bc4eb 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -3,6 +3,9 @@ Has all /sso/* routes /sso/key/generate - handles user signing in with SSO and redirects to /sso/callback /sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI + +/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback +/sso/debug/callback - returns the OpenID object returned by the SSO provider """ import asyncio @@ -36,6 +39,9 @@ from litellm.proxy.common_utils.admin_ui_utils import ( admin_ui_disabled, show_missing_vars_in_env, ) +from litellm.proxy.common_utils.html_forms.jwt_display_template import ( + jwt_display_template, +) from litellm.proxy.common_utils.html_forms.ui_login import html_form from litellm.proxy.management_endpoints.internal_user_endpoints import new_user from litellm.proxy.management_endpoints.sso_helper_utils import ( @@ -92,131 +98,29 @@ async def google_login(request: Request): # noqa: PLR0915 missing_env_vars = show_missing_vars_in_env() if missing_env_vars is not None: return missing_env_vars + ui_username = os.getenv("UI_USERNAME") # get url from request - redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - ui_username = os.getenv("UI_USERNAME") - if redirect_url.endswith("/"): - redirect_url += "sso/callback" - else: - redirect_url += "/sso/callback" - # Google SSO Auth - if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO + redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso( + request=request, + sso_callback_route="sso/callback", + ) - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - google_sso = GoogleSSO( - client_id=google_client_id, - client_secret=google_client_secret, - redirect_uri=redirect_url, + # Check if we should use SSO handler + if ( + SSOAuthenticationHandler.should_use_sso_handler( + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, ) - verbose_proxy_logger.info( - f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + is True + ): + return await SSOAuthenticationHandler.get_sso_login_redirect( + redirect_url=redirect_url, + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, ) - with google_sso: - return await google_sso.get_login_redirect() - # Microsoft SSO Auth - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - with microsoft_sso: - return await microsoft_sso.get_login_redirect() - elif generic_client_id is not None: - from fastapi_sso.sso.base import DiscoveryDocument - from fastapi_sso.sso.generic import create_provider - - generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) - generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") - generic_authorization_endpoint = os.getenv( - "GENERIC_AUTHORIZATION_ENDPOINT", None - ) - generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) - generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) - if generic_client_secret is None: - raise ProxyException( - message="GENERIC_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_authorization_endpoint is None: - raise ProxyException( - message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_AUTHORIZATION_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_token_endpoint is None: - raise ProxyException( - message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_TOKEN_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_userinfo_endpoint is None: - raise ProxyException( - message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_USERINFO_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - verbose_proxy_logger.debug( - f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" - ) - verbose_proxy_logger.debug( - f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" - ) - discovery = DiscoveryDocument( - authorization_endpoint=generic_authorization_endpoint, - token_endpoint=generic_token_endpoint, - userinfo_endpoint=generic_userinfo_endpoint, - ) - SSOProvider = create_provider(name="oidc", discovery_document=discovery) - generic_sso = SSOProvider( - client_id=generic_client_id, - client_secret=generic_client_secret, - redirect_uri=redirect_url, - allow_insecure_http=True, - scope=generic_scope, - ) - with generic_sso: - # TODO: state should be a random string and added to the user session with cookie - # or a cryptographicly signed state that we can verify stateless - # For simplification we are using a static state, this is not perfect but some - # SSO providers do not allow stateless verification - redirect_params = {} - state = os.getenv("GENERIC_CLIENT_STATE", None) - - if state: - redirect_params["state"] = state - elif "okta" in generic_authorization_endpoint: - redirect_params[ - "state" - ] = uuid.uuid4().hex # set state param for okta - required - return await generic_sso.get_login_redirect(**redirect_params) # type: ignore elif ui_username is not None: # No Google, Microsoft SSO # Use UI Credentials set in .env @@ -271,7 +175,7 @@ async def get_generic_sso_response( jwt_handler: JWTHandler, generic_client_id: str, redirect_url: str, -) -> Optional[OpenID]: +) -> Union[OpenID, dict]: # make generic sso provider from fastapi_sso.sso.base import DiscoveryDocument from fastapi_sso.sso.generic import create_provider @@ -348,7 +252,7 @@ async def get_generic_sso_response( request, params={"include_client_id": generic_include_client_id} ) verbose_proxy_logger.debug("generic result: %s", result) - return result + return result or {} async def create_team_member_add_task(team_id, user_info): @@ -443,54 +347,16 @@ async def auth_callback(request: Request): # noqa: PLR0915 result = None if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO - - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - google_sso = GoogleSSO( - client_id=google_client_id, - redirect_uri=redirect_url, - client_secret=google_client_secret, - ) - result = await google_sso.verify_and_process(request) - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if microsoft_tenant is None: - raise ProxyException( - message="MICROSOFT_TENANT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_TENANT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - original_msft_result = await microsoft_sso.verify_and_process( + result = await GoogleSSOHandler.get_google_callback_response( request=request, - convert_response=False, + google_client_id=google_client_id, + redirect_url=redirect_url, ) - result = MicrosoftSSOHandler.openid_from_response( - response=original_msft_result, + elif microsoft_client_id is not None: + result = await MicrosoftSSOHandler.get_microsoft_callback_response( + request=request, + microsoft_client_id=microsoft_client_id, + redirect_url=redirect_url, jwt_handler=jwt_handler, ) elif generic_client_id is not None: @@ -705,7 +571,7 @@ async def auth_callback(request: Request): # noqa: PLR0915 async def insert_sso_user( - result_openid: Optional[OpenID], + result_openid: Optional[Union[OpenID, dict]], user_defined_values: Optional[SSOUserDefinedValues] = None, ) -> NewUserResponse: """ @@ -721,6 +587,10 @@ async def insert_sso_user( verbose_proxy_logger.debug( f"Inserting SSO user into DB. User values: {user_defined_values}" ) + if result_openid is None: + raise ValueError("result_openid is None") + if isinstance(result_openid, dict): + result_openid = OpenID(**result_openid) if user_defined_values is None: raise ValueError("user_defined_values is None") @@ -733,9 +603,9 @@ async def insert_sso_user( if user_defined_values.get("max_budget") is None: user_defined_values["max_budget"] = litellm.max_internal_user_budget if user_defined_values.get("budget_duration") is None: - user_defined_values[ - "budget_duration" - ] = litellm.internal_user_budget_duration + user_defined_values["budget_duration"] = ( + litellm.internal_user_budget_duration + ) if user_defined_values["user_role"] is None: user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY @@ -789,11 +659,242 @@ async def get_ui_settings(request: Request): } +class SSOAuthenticationHandler: + """ + Handler for SSO Authentication across all SSO providers + """ + + @staticmethod + async def get_sso_login_redirect( + redirect_url: str, + google_client_id: Optional[str] = None, + microsoft_client_id: Optional[str] = None, + generic_client_id: Optional[str] = None, + ) -> Optional[RedirectResponse]: + """ + Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url` + + Args: + redirect_url (str): The URL to redirect the user to after login + google_client_id (Optional[str], optional): The Google Client ID. Defaults to None. + microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None. + generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None. + + Returns: + RedirectResponse: The redirect response from the SSO provider + """ + # Google SSO Auth + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + client_secret=google_client_secret, + redirect_uri=redirect_url, + ) + verbose_proxy_logger.info( + f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + ) + with google_sso: + return await google_sso.get_login_redirect() + # Microsoft SSO Auth + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + with microsoft_sso: + return await microsoft_sso.get_login_redirect() + elif generic_client_id is not None: + from fastapi_sso.sso.base import DiscoveryDocument + from fastapi_sso.sso.generic import create_provider + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split( + " " + ) + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + SSOProvider = create_provider(name="oidc", discovery_document=discovery) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + scope=generic_scope, + ) + with generic_sso: + # TODO: state should be a random string and added to the user session with cookie + # or a cryptographicly signed state that we can verify stateless + # For simplification we are using a static state, this is not perfect but some + # SSO providers do not allow stateless verification + redirect_params = {} + state = os.getenv("GENERIC_CLIENT_STATE", None) + + if state: + redirect_params["state"] = state + elif "okta" in generic_authorization_endpoint: + redirect_params["state"] = ( + uuid.uuid4().hex + ) # set state param for okta - required + return await generic_sso.get_login_redirect(**redirect_params) # type: ignore + raise ValueError( + "Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso" + ) + + @staticmethod + def should_use_sso_handler( + google_client_id: Optional[str] = None, + microsoft_client_id: Optional[str] = None, + generic_client_id: Optional[str] = None, + ) -> bool: + if ( + google_client_id is not None + or microsoft_client_id is not None + or generic_client_id is not None + ): + return True + return False + + @staticmethod + def get_redirect_url_for_sso( + request: Request, + sso_callback_route: str, + ) -> str: + """ + Get the redirect URL for SSO + """ + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + if redirect_url.endswith("/"): + redirect_url += sso_callback_route + else: + redirect_url += "/" + sso_callback_route + return redirect_url + + class MicrosoftSSOHandler: """ Handles Microsoft SSO callback response and returns a CustomOpenID object """ + @staticmethod + async def get_microsoft_callback_response( + request: Request, + microsoft_client_id: str, + redirect_url: str, + jwt_handler: JWTHandler, + return_raw_sso_response: bool = False, + ) -> Union[CustomOpenID, OpenID, dict]: + """ + Get the Microsoft SSO callback response + + Args: + return_raw_sso_response: If True, return the raw SSO response + """ + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if microsoft_tenant is None: + raise ProxyException( + message="MICROSOFT_TENANT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_TENANT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + original_msft_result = await microsoft_sso.verify_and_process( + request=request, + convert_response=False, + ) + + # if user is trying to get the raw sso response for debugging, return the raw sso response + if return_raw_sso_response: + return original_msft_result or {} + + result = MicrosoftSSOHandler.openid_from_response( + response=original_msft_result, + jwt_handler=jwt_handler, + ) + return result + @staticmethod def openid_from_response( response: Optional[dict], jwt_handler: JWTHandler @@ -811,3 +912,181 @@ class MicrosoftSSOHandler: ) verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}") return openid_response + + +class GoogleSSOHandler: + """ + Handles Google SSO callback response and returns a CustomOpenID object + """ + + @staticmethod + async def get_google_callback_response( + request: Request, + google_client_id: str, + redirect_url: str, + return_raw_sso_response: bool = False, + ) -> Union[OpenID, dict]: + """ + Get the Google SSO callback response + + Args: + return_raw_sso_response: If True, return the raw SSO response + """ + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + redirect_uri=redirect_url, + client_secret=google_client_secret, + ) + + # if user is trying to get the raw sso response for debugging, return the raw sso response + if return_raw_sso_response: + return ( + await google_sso.verify_and_process( + request=request, + convert_response=False, + ) + or {} + ) + + result = await google_sso.verify_and_process(request) + return result or {} + + +@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False) +async def debug_sso_login(request: Request): + """ + Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env + PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" + Example: + """ + from litellm.proxy.proxy_server import premium_user + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + + ####### Check if user is a Enterprise / Premium User ####### + if ( + microsoft_client_id is not None + or google_client_id is not None + or generic_client_id is not None + ): + if premium_user is not True: + raise ProxyException( + message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this", + type=ProxyErrorTypes.auth_error, + param="premium_user", + code=status.HTTP_403_FORBIDDEN, + ) + + # get url from request + redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso( + request=request, + sso_callback_route="sso/debug/callback", + ) + + # Check if we should use SSO handler + if ( + SSOAuthenticationHandler.should_use_sso_handler( + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, + ) + is True + ): + return await SSOAuthenticationHandler.get_sso_login_redirect( + redirect_url=redirect_url, + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, + ) + + +@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False) +async def debug_sso_callback(request: Request): + """ + Returns the OpenID object returned by the SSO provider + """ + import json + + from fastapi.responses import HTMLResponse + + from litellm.proxy.proxy_server import jwt_handler + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + if redirect_url.endswith("/"): + redirect_url += "sso/debug/callback" + else: + redirect_url += "/sso/debug/callback" + + result = None + if google_client_id is not None: + result = await GoogleSSOHandler.get_google_callback_response( + request=request, + google_client_id=google_client_id, + redirect_url=redirect_url, + return_raw_sso_response=True, + ) + elif microsoft_client_id is not None: + result = await MicrosoftSSOHandler.get_microsoft_callback_response( + request=request, + microsoft_client_id=microsoft_client_id, + redirect_url=redirect_url, + jwt_handler=jwt_handler, + return_raw_sso_response=True, + ) + elif generic_client_id is not None: + result = await get_generic_sso_response( + request=request, + jwt_handler=jwt_handler, + generic_client_id=generic_client_id, + redirect_url=redirect_url, + ) + + # If result is None, return a basic error message + if result is None: + return HTMLResponse( + content="

SSO Authentication Failed

No data was returned from the SSO provider.

", + status_code=400, + ) + + # Convert the OpenID object to a dictionary + if hasattr(result, "__dict__"): + result_dict = result.__dict__ + else: + result_dict = dict(result) + + # Filter out any None values and convert to JSON serializable format + filtered_result = {} + for key, value in result_dict.items(): + if value is not None and not key.startswith("_"): + if isinstance(value, (str, int, float, bool)) or value is None: + filtered_result[key] = value + else: + try: + # Try to convert to string or another JSON serializable format + filtered_result[key] = str(value) + except Exception as e: + filtered_result[key] = f"Complex value (not displayable): {str(e)}" + + # Replace the placeholder in the template with the actual data + html_content = jwt_display_template.replace( + "const userData = SSO_DATA;", + f"const userData = {json.dumps(filtered_result, indent=2)};", + ) + + return HTMLResponse(content=html_content) diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py index 7ad520f7d5..b785b01f8c 100644 --- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py @@ -1,3 +1,4 @@ +import asyncio import json import os import sys @@ -5,15 +6,19 @@ from typing import Optional, cast from unittest.mock import MagicMock, patch import pytest +from fastapi import Request from fastapi.testclient import TestClient sys.path.insert( - 0, os.path.abspath("../../..") + 0, os.path.abspath("../../../") ) # Adds the parent directory to the system path from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.types import CustomOpenID -from litellm.proxy.management_endpoints.ui_sso import MicrosoftSSOHandler +from litellm.proxy.management_endpoints.ui_sso import ( + GoogleSSOHandler, + MicrosoftSSOHandler, +) def test_microsoft_sso_handler_openid_from_response(): @@ -79,3 +84,125 @@ def test_microsoft_sso_handler_with_empty_response(): # Make sure the JWT handler was called with an empty dict mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with({}) + + +def test_get_microsoft_callback_response(): + # Arrange + mock_request = MagicMock(spec=Request) + mock_jwt_handler = MagicMock(spec=JWTHandler) + mock_response = { + "mail": "microsoft_user@example.com", + "displayName": "Microsoft User", + "id": "msft123", + "givenName": "Microsoft", + "surname": "User", + } + + future = asyncio.Future() + future.set_result(mock_response) + + with patch.dict( + os.environ, + {"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"}, + ): + with patch( + "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process", + return_value=future, + ): + # Act + result = asyncio.run( + MicrosoftSSOHandler.get_microsoft_callback_response( + request=mock_request, + microsoft_client_id="mock_client_id", + redirect_url="http://mock_redirect_url", + jwt_handler=mock_jwt_handler, + ) + ) + + # Assert + assert isinstance(result, CustomOpenID) + assert result.email == "microsoft_user@example.com" + assert result.display_name == "Microsoft User" + assert result.provider == "microsoft" + assert result.id == "msft123" + assert result.first_name == "Microsoft" + assert result.last_name == "User" + + +def test_get_microsoft_callback_response_raw_sso_response(): + # Arrange + mock_request = MagicMock(spec=Request) + mock_jwt_handler = MagicMock(spec=JWTHandler) + mock_response = { + "mail": "microsoft_user@example.com", + "displayName": "Microsoft User", + "id": "msft123", + "givenName": "Microsoft", + "surname": "User", + } + + future = asyncio.Future() + future.set_result(mock_response) + with patch.dict( + os.environ, + {"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"}, + ): + with patch( + "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process", + return_value=future, + ): + # Act + result = asyncio.run( + MicrosoftSSOHandler.get_microsoft_callback_response( + request=mock_request, + microsoft_client_id="mock_client_id", + redirect_url="http://mock_redirect_url", + jwt_handler=mock_jwt_handler, + return_raw_sso_response=True, + ) + ) + + # Assert + print("result from verify_and_process", result) + assert isinstance(result, dict) + assert result["mail"] == "microsoft_user@example.com" + assert result["displayName"] == "Microsoft User" + assert result["id"] == "msft123" + assert result["givenName"] == "Microsoft" + assert result["surname"] == "User" + + +def test_get_google_callback_response(): + # Arrange + mock_request = MagicMock(spec=Request) + mock_response = { + "email": "google_user@example.com", + "name": "Google User", + "sub": "google123", + "given_name": "Google", + "family_name": "User", + } + + future = asyncio.Future() + future.set_result(mock_response) + + with patch.dict(os.environ, {"GOOGLE_CLIENT_SECRET": "mock_secret"}): + with patch( + "fastapi_sso.sso.google.GoogleSSO.verify_and_process", return_value=future + ): + # Act + result = asyncio.run( + GoogleSSOHandler.get_google_callback_response( + request=mock_request, + google_client_id="mock_client_id", + redirect_url="http://mock_redirect_url", + ) + ) + + # Assert + assert isinstance(result, dict) + assert result.get("email") == "google_user@example.com" + assert result.get("name") == "Google User" + assert result.get("sub") == "google123" + assert result.get("given_name") == "Google" + assert result.get("family_name") == "User"