mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
refactor SSO handler
This commit is contained in:
parent
73356b3a9f
commit
89cf042541
1 changed files with 368 additions and 167 deletions
|
@ -3,6 +3,9 @@ Has all /sso/* routes
|
||||||
|
|
||||||
/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback
|
/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback
|
||||||
/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI
|
/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI
|
||||||
|
|
||||||
|
/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback
|
||||||
|
/sso/debug/callback - returns the OpenID object returned by the SSO provider
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -92,131 +95,29 @@ async def google_login(request: Request): # noqa: PLR0915
|
||||||
missing_env_vars = show_missing_vars_in_env()
|
missing_env_vars = show_missing_vars_in_env()
|
||||||
if missing_env_vars is not None:
|
if missing_env_vars is not None:
|
||||||
return missing_env_vars
|
return missing_env_vars
|
||||||
|
ui_username = os.getenv("UI_USERNAME")
|
||||||
|
|
||||||
# get url from request
|
# get url from request
|
||||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
|
||||||
ui_username = os.getenv("UI_USERNAME")
|
request=request,
|
||||||
if redirect_url.endswith("/"):
|
sso_callback_route="sso/callback",
|
||||||
redirect_url += "sso/callback"
|
)
|
||||||
else:
|
|
||||||
redirect_url += "/sso/callback"
|
|
||||||
# Google SSO Auth
|
|
||||||
if google_client_id is not None:
|
|
||||||
from fastapi_sso.sso.google import GoogleSSO
|
|
||||||
|
|
||||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
# Check if we should use SSO handler
|
||||||
if google_client_secret is None:
|
if (
|
||||||
raise ProxyException(
|
SSOAuthenticationHandler.should_use_sso_handler(
|
||||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
microsoft_client_id=microsoft_client_id,
|
||||||
type=ProxyErrorTypes.auth_error,
|
google_client_id=google_client_id,
|
||||||
param="GOOGLE_CLIENT_SECRET",
|
generic_client_id=generic_client_id,
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
google_sso = GoogleSSO(
|
|
||||||
client_id=google_client_id,
|
|
||||||
client_secret=google_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.info(
|
is True
|
||||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
):
|
||||||
|
return SSOAuthenticationHandler.get_sso_login_redirect(
|
||||||
|
redirect_url=redirect_url,
|
||||||
|
microsoft_client_id=microsoft_client_id,
|
||||||
|
google_client_id=google_client_id,
|
||||||
|
generic_client_id=generic_client_id,
|
||||||
)
|
)
|
||||||
with google_sso:
|
|
||||||
return await google_sso.get_login_redirect()
|
|
||||||
# Microsoft SSO Auth
|
|
||||||
elif microsoft_client_id is not None:
|
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
|
||||||
|
|
||||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
|
||||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
|
||||||
if microsoft_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="MICROSOFT_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
microsoft_sso = MicrosoftSSO(
|
|
||||||
client_id=microsoft_client_id,
|
|
||||||
client_secret=microsoft_client_secret,
|
|
||||||
tenant=microsoft_tenant,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
)
|
|
||||||
with microsoft_sso:
|
|
||||||
return await microsoft_sso.get_login_redirect()
|
|
||||||
elif generic_client_id is not None:
|
|
||||||
from fastapi_sso.sso.base import DiscoveryDocument
|
|
||||||
from fastapi_sso.sso.generic import create_provider
|
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
|
||||||
generic_authorization_endpoint = os.getenv(
|
|
||||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
||||||
)
|
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
||||||
if generic_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_authorization_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_token_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_TOKEN_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_userinfo_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_USERINFO_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
||||||
)
|
|
||||||
discovery = DiscoveryDocument(
|
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
|
||||||
token_endpoint=generic_token_endpoint,
|
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
|
||||||
)
|
|
||||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
|
||||||
generic_sso = SSOProvider(
|
|
||||||
client_id=generic_client_id,
|
|
||||||
client_secret=generic_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
scope=generic_scope,
|
|
||||||
)
|
|
||||||
with generic_sso:
|
|
||||||
# TODO: state should be a random string and added to the user session with cookie
|
|
||||||
# or a cryptographicly signed state that we can verify stateless
|
|
||||||
# For simplification we are using a static state, this is not perfect but some
|
|
||||||
# SSO providers do not allow stateless verification
|
|
||||||
redirect_params = {}
|
|
||||||
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
|
||||||
|
|
||||||
if state:
|
|
||||||
redirect_params["state"] = state
|
|
||||||
elif "okta" in generic_authorization_endpoint:
|
|
||||||
redirect_params[
|
|
||||||
"state"
|
|
||||||
] = uuid.uuid4().hex # set state param for okta - required
|
|
||||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
|
||||||
elif ui_username is not None:
|
elif ui_username is not None:
|
||||||
# No Google, Microsoft SSO
|
# No Google, Microsoft SSO
|
||||||
# Use UI Credentials set in .env
|
# Use UI Credentials set in .env
|
||||||
|
@ -407,7 +308,7 @@ def get_disabled_non_admin_personal_key_creation():
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
||||||
async def auth_callback(request: Request): # noqa: PLR0915
|
async def auth_callback(request: Request):
|
||||||
"""Verify login"""
|
"""Verify login"""
|
||||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
generate_key_helper_fn,
|
generate_key_helper_fn,
|
||||||
|
@ -443,54 +344,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
if google_client_id is not None:
|
if google_client_id is not None:
|
||||||
from fastapi_sso.sso.google import GoogleSSO
|
result = await GoogleSSOHandler.get_google_callback_response(
|
||||||
|
|
||||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
|
||||||
if google_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GOOGLE_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
google_sso = GoogleSSO(
|
|
||||||
client_id=google_client_id,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
client_secret=google_client_secret,
|
|
||||||
)
|
|
||||||
result = await google_sso.verify_and_process(request)
|
|
||||||
elif microsoft_client_id is not None:
|
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
|
||||||
|
|
||||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
|
||||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
|
||||||
if microsoft_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="MICROSOFT_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if microsoft_tenant is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_TENANT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="MICROSOFT_TENANT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
microsoft_sso = MicrosoftSSO(
|
|
||||||
client_id=microsoft_client_id,
|
|
||||||
client_secret=microsoft_client_secret,
|
|
||||||
tenant=microsoft_tenant,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
)
|
|
||||||
original_msft_result = await microsoft_sso.verify_and_process(
|
|
||||||
request=request,
|
request=request,
|
||||||
convert_response=False,
|
google_client_id=google_client_id,
|
||||||
|
redirect_url=redirect_url,
|
||||||
)
|
)
|
||||||
result = MicrosoftSSOHandler.openid_from_response(
|
elif microsoft_client_id is not None:
|
||||||
response=original_msft_result,
|
result = await MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||||
|
request=request,
|
||||||
|
microsoft_client_id=microsoft_client_id,
|
||||||
|
redirect_url=redirect_url,
|
||||||
jwt_handler=jwt_handler,
|
jwt_handler=jwt_handler,
|
||||||
)
|
)
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
|
@ -789,11 +652,227 @@ async def get_ui_settings(request: Request):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SSOAuthenticationHandler:
|
||||||
|
"""
|
||||||
|
Handler for SSO Authentication across all SSO providers
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_sso_login_redirect(
|
||||||
|
redirect_url: str,
|
||||||
|
google_client_id: Optional[str] = None,
|
||||||
|
microsoft_client_id: Optional[str] = None,
|
||||||
|
generic_client_id: Optional[str] = None,
|
||||||
|
) -> Optional[RedirectResponse]:
|
||||||
|
"""
|
||||||
|
Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_url (str): The URL to redirect the user to after login
|
||||||
|
google_client_id (Optional[str], optional): The Google Client ID. Defaults to None.
|
||||||
|
microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None.
|
||||||
|
generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedirectResponse: The redirect response from the SSO provider
|
||||||
|
"""
|
||||||
|
# Google SSO Auth
|
||||||
|
if google_client_id is not None:
|
||||||
|
from fastapi_sso.sso.google import GoogleSSO
|
||||||
|
|
||||||
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||||
|
if google_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GOOGLE_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
google_sso = GoogleSSO(
|
||||||
|
client_id=google_client_id,
|
||||||
|
client_secret=google_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
||||||
|
)
|
||||||
|
with google_sso:
|
||||||
|
return await google_sso.get_login_redirect()
|
||||||
|
# Microsoft SSO Auth
|
||||||
|
elif microsoft_client_id is not None:
|
||||||
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||||
|
|
||||||
|
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||||
|
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||||
|
if microsoft_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="MICROSOFT_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
microsoft_sso = MicrosoftSSO(
|
||||||
|
client_id=microsoft_client_id,
|
||||||
|
client_secret=microsoft_client_secret,
|
||||||
|
tenant=microsoft_tenant,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
with microsoft_sso:
|
||||||
|
return await microsoft_sso.get_login_redirect()
|
||||||
|
elif generic_client_id is not None:
|
||||||
|
from fastapi_sso.sso.base import DiscoveryDocument
|
||||||
|
from fastapi_sso.sso.generic import create_provider
|
||||||
|
|
||||||
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(
|
||||||
|
" "
|
||||||
|
)
|
||||||
|
generic_authorization_endpoint = os.getenv(
|
||||||
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||||
|
)
|
||||||
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
if generic_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_authorization_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_token_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_TOKEN_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_userinfo_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
|
)
|
||||||
|
discovery = DiscoveryDocument(
|
||||||
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
|
token_endpoint=generic_token_endpoint,
|
||||||
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
|
)
|
||||||
|
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||||
|
generic_sso = SSOProvider(
|
||||||
|
client_id=generic_client_id,
|
||||||
|
client_secret=generic_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
scope=generic_scope,
|
||||||
|
)
|
||||||
|
with generic_sso:
|
||||||
|
# TODO: state should be a random string and added to the user session with cookie
|
||||||
|
# or a cryptographicly signed state that we can verify stateless
|
||||||
|
# For simplification we are using a static state, this is not perfect but some
|
||||||
|
# SSO providers do not allow stateless verification
|
||||||
|
redirect_params = {}
|
||||||
|
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||||
|
|
||||||
|
if state:
|
||||||
|
redirect_params["state"] = state
|
||||||
|
elif "okta" in generic_authorization_endpoint:
|
||||||
|
redirect_params[
|
||||||
|
"state"
|
||||||
|
] = uuid.uuid4().hex # set state param for okta - required
|
||||||
|
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_use_sso_handler(
|
||||||
|
google_client_id: Optional[str] = None,
|
||||||
|
microsoft_client_id: Optional[str] = None,
|
||||||
|
generic_client_id: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
if (
|
||||||
|
google_client_id is not None
|
||||||
|
or microsoft_client_id is not None
|
||||||
|
or generic_client_id is not None
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_redirect_url_for_sso(
|
||||||
|
request: Request,
|
||||||
|
sso_callback_route: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get the redirect URL for SSO
|
||||||
|
"""
|
||||||
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||||
|
if redirect_url.endswith("/"):
|
||||||
|
redirect_url += sso_callback_route
|
||||||
|
else:
|
||||||
|
redirect_url += "/" + sso_callback_route
|
||||||
|
return redirect_url
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftSSOHandler:
|
class MicrosoftSSOHandler:
|
||||||
"""
|
"""
|
||||||
Handles Microsoft SSO callback response and returns a CustomOpenID object
|
Handles Microsoft SSO callback response and returns a CustomOpenID object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_microsoft_callback_response(
|
||||||
|
request: Request,
|
||||||
|
microsoft_client_id: str,
|
||||||
|
redirect_url: str,
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
) -> CustomOpenID:
|
||||||
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||||
|
|
||||||
|
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||||
|
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||||
|
if microsoft_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="MICROSOFT_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if microsoft_tenant is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_TENANT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="MICROSOFT_TENANT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
microsoft_sso = MicrosoftSSO(
|
||||||
|
client_id=microsoft_client_id,
|
||||||
|
client_secret=microsoft_client_secret,
|
||||||
|
tenant=microsoft_tenant,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
original_msft_result = await microsoft_sso.verify_and_process(
|
||||||
|
request=request,
|
||||||
|
convert_response=False,
|
||||||
|
)
|
||||||
|
result = MicrosoftSSOHandler.openid_from_response(
|
||||||
|
response=original_msft_result,
|
||||||
|
jwt_handler=jwt_handler,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def openid_from_response(
|
def openid_from_response(
|
||||||
response: Optional[dict], jwt_handler: JWTHandler
|
response: Optional[dict], jwt_handler: JWTHandler
|
||||||
|
@ -811,3 +890,125 @@ class MicrosoftSSOHandler:
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
|
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
|
||||||
return openid_response
|
return openid_response
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleSSOHandler:
|
||||||
|
"""
|
||||||
|
Handles Google SSO callback response and returns a CustomOpenID object
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_google_callback_response(
|
||||||
|
request: Request,
|
||||||
|
google_client_id: str,
|
||||||
|
redirect_url: str,
|
||||||
|
) -> Optional[OpenID]:
|
||||||
|
from fastapi_sso.sso.google import GoogleSSO
|
||||||
|
|
||||||
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||||
|
if google_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GOOGLE_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
google_sso = GoogleSSO(
|
||||||
|
client_id=google_client_id,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
client_secret=google_client_secret,
|
||||||
|
)
|
||||||
|
result = await google_sso.verify_and_process(request)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False)
|
||||||
|
async def debug_sso_login(request: Request):
|
||||||
|
"""
|
||||||
|
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
||||||
|
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
||||||
|
Example:
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import premium_user
|
||||||
|
|
||||||
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
|
|
||||||
|
####### Check if user is a Enterprise / Premium User #######
|
||||||
|
if (
|
||||||
|
microsoft_client_id is not None
|
||||||
|
or google_client_id is not None
|
||||||
|
or generic_client_id is not None
|
||||||
|
):
|
||||||
|
if premium_user is not True:
|
||||||
|
raise ProxyException(
|
||||||
|
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="premium_user",
|
||||||
|
code=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get url from request
|
||||||
|
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
|
||||||
|
request=request,
|
||||||
|
sso_callback_route="sso/debug/callback",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if we should use SSO handler
|
||||||
|
if (
|
||||||
|
SSOAuthenticationHandler.should_use_sso_handler(
|
||||||
|
microsoft_client_id=microsoft_client_id,
|
||||||
|
google_client_id=google_client_id,
|
||||||
|
generic_client_id=generic_client_id,
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
):
|
||||||
|
return await SSOAuthenticationHandler.get_sso_login_redirect(
|
||||||
|
redirect_url=redirect_url,
|
||||||
|
microsoft_client_id=microsoft_client_id,
|
||||||
|
google_client_id=google_client_id,
|
||||||
|
generic_client_id=generic_client_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False)
|
||||||
|
async def debug_sso_callback(request: Request):
|
||||||
|
"""
|
||||||
|
Returns the OpenID object returned by the SSO provider
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import jwt_handler
|
||||||
|
|
||||||
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
|
|
||||||
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||||
|
if redirect_url.endswith("/"):
|
||||||
|
redirect_url += "sso/debug/callback"
|
||||||
|
else:
|
||||||
|
redirect_url += "/sso/debug/callback"
|
||||||
|
|
||||||
|
result = None
|
||||||
|
if google_client_id is not None:
|
||||||
|
result = await GoogleSSOHandler.get_google_callback_response(
|
||||||
|
request=request,
|
||||||
|
google_client_id=google_client_id,
|
||||||
|
redirect_url=redirect_url,
|
||||||
|
)
|
||||||
|
elif microsoft_client_id is not None:
|
||||||
|
result = await MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||||
|
request=request,
|
||||||
|
microsoft_client_id=microsoft_client_id,
|
||||||
|
redirect_url=redirect_url,
|
||||||
|
jwt_handler=jwt_handler,
|
||||||
|
)
|
||||||
|
elif generic_client_id is not None:
|
||||||
|
result = await get_generic_sso_response(
|
||||||
|
request=request,
|
||||||
|
jwt_handler=jwt_handler,
|
||||||
|
generic_client_id=generic_client_id,
|
||||||
|
redirect_url=redirect_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue