refactor SSO handler

This commit is contained in:
Ishaan Jaff 2025-04-08 15:20:50 -07:00
parent 73356b3a9f
commit 89cf042541

View file

@ -3,6 +3,9 @@ Has all /sso/* routes
/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback /sso/key/generate - handles user signing in with SSO and redirects to /sso/callback
/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI /sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI
/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback
/sso/debug/callback - returns the OpenID object returned by the SSO provider
""" """
import asyncio import asyncio
@ -92,131 +95,29 @@ async def google_login(request: Request): # noqa: PLR0915
missing_env_vars = show_missing_vars_in_env() missing_env_vars = show_missing_vars_in_env()
if missing_env_vars is not None: if missing_env_vars is not None:
return missing_env_vars return missing_env_vars
ui_username = os.getenv("UI_USERNAME")
# get url from request # get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
ui_username = os.getenv("UI_USERNAME") request=request,
if redirect_url.endswith("/"): sso_callback_route="sso/callback",
redirect_url += "sso/callback" )
else:
redirect_url += "/sso/callback"
# Google SSO Auth
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) # Check if we should use SSO handler
if google_client_secret is None: if (
raise ProxyException( SSOAuthenticationHandler.should_use_sso_handler(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", microsoft_client_id=microsoft_client_id,
type=ProxyErrorTypes.auth_error, google_client_id=google_client_id,
param="GOOGLE_CLIENT_SECRET", generic_client_id=generic_client_id,
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
client_secret=google_client_secret,
redirect_uri=redirect_url,
) )
verbose_proxy_logger.info( is True
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" ):
return SSOAuthenticationHandler.get_sso_login_redirect(
redirect_url=redirect_url,
microsoft_client_id=microsoft_client_id,
google_client_id=google_client_id,
generic_client_id=generic_client_id,
) )
with google_sso:
return await google_sso.get_login_redirect()
# Microsoft SSO Auth
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
scope=generic_scope,
)
with generic_sso:
# TODO: state should be a random string and added to the user session with cookie
# or a cryptographicly signed state that we can verify stateless
# For simplification we are using a static state, this is not perfect but some
# SSO providers do not allow stateless verification
redirect_params = {}
state = os.getenv("GENERIC_CLIENT_STATE", None)
if state:
redirect_params["state"] = state
elif "okta" in generic_authorization_endpoint:
redirect_params[
"state"
] = uuid.uuid4().hex # set state param for okta - required
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
elif ui_username is not None: elif ui_username is not None:
# No Google, Microsoft SSO # No Google, Microsoft SSO
# Use UI Credentials set in .env # Use UI Credentials set in .env
@ -407,7 +308,7 @@ def get_disabled_non_admin_personal_key_creation():
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False) @router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
async def auth_callback(request: Request): # noqa: PLR0915 async def auth_callback(request: Request):
"""Verify login""" """Verify login"""
from litellm.proxy.management_endpoints.key_management_endpoints import ( from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_helper_fn, generate_key_helper_fn,
@ -443,54 +344,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
result = None result = None
if google_client_id is not None: if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO result = await GoogleSSOHandler.get_google_callback_response(
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
redirect_uri=redirect_url,
client_secret=google_client_secret,
)
result = await google_sso.verify_and_process(request)
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if microsoft_tenant is None:
raise ProxyException(
message="MICROSOFT_TENANT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="MICROSOFT_TENANT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
original_msft_result = await microsoft_sso.verify_and_process(
request=request, request=request,
convert_response=False, google_client_id=google_client_id,
redirect_url=redirect_url,
) )
result = MicrosoftSSOHandler.openid_from_response( elif microsoft_client_id is not None:
response=original_msft_result, result = await MicrosoftSSOHandler.get_microsoft_callback_response(
request=request,
microsoft_client_id=microsoft_client_id,
redirect_url=redirect_url,
jwt_handler=jwt_handler, jwt_handler=jwt_handler,
) )
elif generic_client_id is not None: elif generic_client_id is not None:
@ -789,11 +652,227 @@ async def get_ui_settings(request: Request):
} }
class SSOAuthenticationHandler:
"""
Handler for SSO Authentication across all SSO providers
"""
@staticmethod
async def get_sso_login_redirect(
redirect_url: str,
google_client_id: Optional[str] = None,
microsoft_client_id: Optional[str] = None,
generic_client_id: Optional[str] = None,
) -> Optional[RedirectResponse]:
"""
Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url`
Args:
redirect_url (str): The URL to redirect the user to after login
google_client_id (Optional[str], optional): The Google Client ID. Defaults to None.
microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None.
generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None.
Returns:
RedirectResponse: The redirect response from the SSO provider
"""
# Google SSO Auth
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
client_secret=google_client_secret,
redirect_uri=redirect_url,
)
verbose_proxy_logger.info(
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
)
with google_sso:
return await google_sso.get_login_redirect()
# Microsoft SSO Auth
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(
" "
)
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
scope=generic_scope,
)
with generic_sso:
# TODO: state should be a random string and added to the user session with cookie
# or a cryptographicly signed state that we can verify stateless
# For simplification we are using a static state, this is not perfect but some
# SSO providers do not allow stateless verification
redirect_params = {}
state = os.getenv("GENERIC_CLIENT_STATE", None)
if state:
redirect_params["state"] = state
elif "okta" in generic_authorization_endpoint:
redirect_params[
"state"
] = uuid.uuid4().hex # set state param for okta - required
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
@staticmethod
def should_use_sso_handler(
google_client_id: Optional[str] = None,
microsoft_client_id: Optional[str] = None,
generic_client_id: Optional[str] = None,
) -> bool:
if (
google_client_id is not None
or microsoft_client_id is not None
or generic_client_id is not None
):
return True
return False
@staticmethod
def get_redirect_url_for_sso(
request: Request,
sso_callback_route: str,
) -> str:
"""
Get the redirect URL for SSO
"""
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
if redirect_url.endswith("/"):
redirect_url += sso_callback_route
else:
redirect_url += "/" + sso_callback_route
return redirect_url
class MicrosoftSSOHandler: class MicrosoftSSOHandler:
""" """
Handles Microsoft SSO callback response and returns a CustomOpenID object Handles Microsoft SSO callback response and returns a CustomOpenID object
""" """
@staticmethod
async def get_microsoft_callback_response(
request: Request,
microsoft_client_id: str,
redirect_url: str,
jwt_handler: JWTHandler,
) -> CustomOpenID:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if microsoft_tenant is None:
raise ProxyException(
message="MICROSOFT_TENANT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="MICROSOFT_TENANT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
original_msft_result = await microsoft_sso.verify_and_process(
request=request,
convert_response=False,
)
result = MicrosoftSSOHandler.openid_from_response(
response=original_msft_result,
jwt_handler=jwt_handler,
)
return result
@staticmethod @staticmethod
def openid_from_response( def openid_from_response(
response: Optional[dict], jwt_handler: JWTHandler response: Optional[dict], jwt_handler: JWTHandler
@ -811,3 +890,125 @@ class MicrosoftSSOHandler:
) )
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}") verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
return openid_response return openid_response
class GoogleSSOHandler:
"""
Handles Google SSO callback response and returns a CustomOpenID object
"""
@staticmethod
async def get_google_callback_response(
request: Request,
google_client_id: str,
redirect_url: str,
) -> Optional[OpenID]:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
redirect_uri=redirect_url,
client_secret=google_client_secret,
)
result = await google_sso.verify_and_process(request)
return result
@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False)
async def debug_sso_login(request: Request):
"""
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
Example:
"""
from litellm.proxy.proxy_server import premium_user
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
####### Check if user is a Enterprise / Premium User #######
if (
microsoft_client_id is not None
or google_client_id is not None
or generic_client_id is not None
):
if premium_user is not True:
raise ProxyException(
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
type=ProxyErrorTypes.auth_error,
param="premium_user",
code=status.HTTP_403_FORBIDDEN,
)
# get url from request
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
request=request,
sso_callback_route="sso/debug/callback",
)
# Check if we should use SSO handler
if (
SSOAuthenticationHandler.should_use_sso_handler(
microsoft_client_id=microsoft_client_id,
google_client_id=google_client_id,
generic_client_id=generic_client_id,
)
is True
):
return await SSOAuthenticationHandler.get_sso_login_redirect(
redirect_url=redirect_url,
microsoft_client_id=microsoft_client_id,
google_client_id=google_client_id,
generic_client_id=generic_client_id,
)
@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False)
async def debug_sso_callback(request: Request):
"""
Returns the OpenID object returned by the SSO provider
"""
from litellm.proxy.proxy_server import jwt_handler
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
if redirect_url.endswith("/"):
redirect_url += "sso/debug/callback"
else:
redirect_url += "/sso/debug/callback"
result = None
if google_client_id is not None:
result = await GoogleSSOHandler.get_google_callback_response(
request=request,
google_client_id=google_client_id,
redirect_url=redirect_url,
)
elif microsoft_client_id is not None:
result = await MicrosoftSSOHandler.get_microsoft_callback_response(
request=request,
microsoft_client_id=microsoft_client_id,
redirect_url=redirect_url,
jwt_handler=jwt_handler,
)
elif generic_client_id is not None:
result = await get_generic_sso_response(
request=request,
jwt_handler=jwt_handler,
generic_client_id=generic_client_id,
redirect_url=redirect_url,
)
return result