diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 970587ded9..5cf0336445 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -3,6 +3,9 @@ Has all /sso/* routes /sso/key/generate - handles user signing in with SSO and redirects to /sso/callback /sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI + +/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback +/sso/debug/callback - returns the OpenID object returned by the SSO provider """ import asyncio @@ -92,131 +95,29 @@ async def google_login(request: Request): # noqa: PLR0915 missing_env_vars = show_missing_vars_in_env() if missing_env_vars is not None: return missing_env_vars + ui_username = os.getenv("UI_USERNAME") # get url from request - redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - ui_username = os.getenv("UI_USERNAME") - if redirect_url.endswith("/"): - redirect_url += "sso/callback" - else: - redirect_url += "/sso/callback" - # Google SSO Auth - if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO + redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso( + request=request, + sso_callback_route="sso/callback", + ) - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - google_sso = GoogleSSO( - client_id=google_client_id, - client_secret=google_client_secret, - redirect_uri=redirect_url, + # Check if we should use SSO handler + if ( + SSOAuthenticationHandler.should_use_sso_handler( + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, ) - verbose_proxy_logger.info( - f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + is True + ): + return SSOAuthenticationHandler.get_sso_login_redirect( + redirect_url=redirect_url, + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, ) - with google_sso: - return await google_sso.get_login_redirect() - # Microsoft SSO Auth - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - with microsoft_sso: - return await microsoft_sso.get_login_redirect() - elif generic_client_id is not None: - from fastapi_sso.sso.base import DiscoveryDocument - from fastapi_sso.sso.generic import create_provider - - generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) - generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") - generic_authorization_endpoint = os.getenv( - "GENERIC_AUTHORIZATION_ENDPOINT", None - ) - generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) - generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) - if generic_client_secret is None: - raise ProxyException( - message="GENERIC_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_authorization_endpoint is None: - raise ProxyException( - message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_AUTHORIZATION_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_token_endpoint is None: - raise ProxyException( - message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_TOKEN_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_userinfo_endpoint is None: - raise ProxyException( - message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_USERINFO_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - verbose_proxy_logger.debug( - f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" - ) - verbose_proxy_logger.debug( - f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" - ) - discovery = DiscoveryDocument( - authorization_endpoint=generic_authorization_endpoint, - token_endpoint=generic_token_endpoint, - userinfo_endpoint=generic_userinfo_endpoint, - ) - SSOProvider = create_provider(name="oidc", discovery_document=discovery) - generic_sso = SSOProvider( - client_id=generic_client_id, - client_secret=generic_client_secret, - redirect_uri=redirect_url, - allow_insecure_http=True, - scope=generic_scope, - ) - with generic_sso: - # TODO: state should be a random string and added to the user session with cookie - # or a cryptographicly signed state that we can verify stateless - # For simplification we are using a static state, this is not perfect but some - # SSO providers do not allow stateless verification - redirect_params = {} - state = os.getenv("GENERIC_CLIENT_STATE", None) - - if state: - redirect_params["state"] = state - elif "okta" in generic_authorization_endpoint: - redirect_params[ - "state" - ] = uuid.uuid4().hex # set state param for okta - required - return await generic_sso.get_login_redirect(**redirect_params) # type: ignore elif ui_username is not None: # No Google, Microsoft SSO # Use UI Credentials set in .env @@ -407,7 +308,7 @@ def get_disabled_non_admin_personal_key_creation(): @router.get("/sso/callback", tags=["experimental"], include_in_schema=False) -async def auth_callback(request: Request): # noqa: PLR0915 +async def auth_callback(request: Request): """Verify login""" from litellm.proxy.management_endpoints.key_management_endpoints import ( generate_key_helper_fn, @@ -443,54 +344,16 @@ async def auth_callback(request: Request): # noqa: PLR0915 result = None if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO - - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - google_sso = GoogleSSO( - client_id=google_client_id, - redirect_uri=redirect_url, - client_secret=google_client_secret, - ) - result = await google_sso.verify_and_process(request) - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if microsoft_tenant is None: - raise ProxyException( - message="MICROSOFT_TENANT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_TENANT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - original_msft_result = await microsoft_sso.verify_and_process( + result = await GoogleSSOHandler.get_google_callback_response( request=request, - convert_response=False, + google_client_id=google_client_id, + redirect_url=redirect_url, ) - result = MicrosoftSSOHandler.openid_from_response( - response=original_msft_result, + elif microsoft_client_id is not None: + result = await MicrosoftSSOHandler.get_microsoft_callback_response( + request=request, + microsoft_client_id=microsoft_client_id, + redirect_url=redirect_url, jwt_handler=jwt_handler, ) elif generic_client_id is not None: @@ -789,11 +652,227 @@ async def get_ui_settings(request: Request): } +class SSOAuthenticationHandler: + """ + Handler for SSO Authentication across all SSO providers + """ + + @staticmethod + async def get_sso_login_redirect( + redirect_url: str, + google_client_id: Optional[str] = None, + microsoft_client_id: Optional[str] = None, + generic_client_id: Optional[str] = None, + ) -> Optional[RedirectResponse]: + """ + Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url` + + Args: + redirect_url (str): The URL to redirect the user to after login + google_client_id (Optional[str], optional): The Google Client ID. Defaults to None. + microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None. + generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None. + + Returns: + RedirectResponse: The redirect response from the SSO provider + """ + # Google SSO Auth + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + client_secret=google_client_secret, + redirect_uri=redirect_url, + ) + verbose_proxy_logger.info( + f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + ) + with google_sso: + return await google_sso.get_login_redirect() + # Microsoft SSO Auth + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + with microsoft_sso: + return await microsoft_sso.get_login_redirect() + elif generic_client_id is not None: + from fastapi_sso.sso.base import DiscoveryDocument + from fastapi_sso.sso.generic import create_provider + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split( + " " + ) + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + SSOProvider = create_provider(name="oidc", discovery_document=discovery) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + scope=generic_scope, + ) + with generic_sso: + # TODO: state should be a random string and added to the user session with cookie + # or a cryptographicly signed state that we can verify stateless + # For simplification we are using a static state, this is not perfect but some + # SSO providers do not allow stateless verification + redirect_params = {} + state = os.getenv("GENERIC_CLIENT_STATE", None) + + if state: + redirect_params["state"] = state + elif "okta" in generic_authorization_endpoint: + redirect_params[ + "state" + ] = uuid.uuid4().hex # set state param for okta - required + return await generic_sso.get_login_redirect(**redirect_params) # type: ignore + + @staticmethod + def should_use_sso_handler( + google_client_id: Optional[str] = None, + microsoft_client_id: Optional[str] = None, + generic_client_id: Optional[str] = None, + ) -> bool: + if ( + google_client_id is not None + or microsoft_client_id is not None + or generic_client_id is not None + ): + return True + return False + + @staticmethod + def get_redirect_url_for_sso( + request: Request, + sso_callback_route: str, + ) -> str: + """ + Get the redirect URL for SSO + """ + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + if redirect_url.endswith("/"): + redirect_url += sso_callback_route + else: + redirect_url += "/" + sso_callback_route + return redirect_url + + class MicrosoftSSOHandler: """ Handles Microsoft SSO callback response and returns a CustomOpenID object """ + @staticmethod + async def get_microsoft_callback_response( + request: Request, + microsoft_client_id: str, + redirect_url: str, + jwt_handler: JWTHandler, + ) -> CustomOpenID: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if microsoft_tenant is None: + raise ProxyException( + message="MICROSOFT_TENANT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_TENANT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + original_msft_result = await microsoft_sso.verify_and_process( + request=request, + convert_response=False, + ) + result = MicrosoftSSOHandler.openid_from_response( + response=original_msft_result, + jwt_handler=jwt_handler, + ) + return result + @staticmethod def openid_from_response( response: Optional[dict], jwt_handler: JWTHandler @@ -811,3 +890,125 @@ class MicrosoftSSOHandler: ) verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}") return openid_response + + +class GoogleSSOHandler: + """ + Handles Google SSO callback response and returns a CustomOpenID object + """ + + @staticmethod + async def get_google_callback_response( + request: Request, + google_client_id: str, + redirect_url: str, + ) -> Optional[OpenID]: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + redirect_uri=redirect_url, + client_secret=google_client_secret, + ) + result = await google_sso.verify_and_process(request) + return result + + +@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False) +async def debug_sso_login(request: Request): + """ + Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env + PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" + Example: + """ + from litellm.proxy.proxy_server import premium_user + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + + ####### Check if user is a Enterprise / Premium User ####### + if ( + microsoft_client_id is not None + or google_client_id is not None + or generic_client_id is not None + ): + if premium_user is not True: + raise ProxyException( + message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this", + type=ProxyErrorTypes.auth_error, + param="premium_user", + code=status.HTTP_403_FORBIDDEN, + ) + + # get url from request + redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso( + request=request, + sso_callback_route="sso/debug/callback", + ) + + # Check if we should use SSO handler + if ( + SSOAuthenticationHandler.should_use_sso_handler( + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, + ) + is True + ): + return await SSOAuthenticationHandler.get_sso_login_redirect( + redirect_url=redirect_url, + microsoft_client_id=microsoft_client_id, + google_client_id=google_client_id, + generic_client_id=generic_client_id, + ) + + +@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False) +async def debug_sso_callback(request: Request): + """ + Returns the OpenID object returned by the SSO provider + """ + from litellm.proxy.proxy_server import jwt_handler + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + if redirect_url.endswith("/"): + redirect_url += "sso/debug/callback" + else: + redirect_url += "/sso/debug/callback" + + result = None + if google_client_id is not None: + result = await GoogleSSOHandler.get_google_callback_response( + request=request, + google_client_id=google_client_id, + redirect_url=redirect_url, + ) + elif microsoft_client_id is not None: + result = await MicrosoftSSOHandler.get_microsoft_callback_response( + request=request, + microsoft_client_id=microsoft_client_id, + redirect_url=redirect_url, + jwt_handler=jwt_handler, + ) + elif generic_client_id is not None: + result = await get_generic_sso_response( + request=request, + jwt_handler=jwt_handler, + generic_client_id=generic_client_id, + redirect_url=redirect_url, + ) + + return result