mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
[Feat SSO] Debug route - allow admins to debug SSO JWT fields (#9835)
* refactor SSO handler * render sso JWT on ui * docs debug sso * fix sso login flow use await * fix ui sso debug JWT * test ui sso * remove redis vl * fix redisvl==0.5.1 * fix ml dtypes * fix redisvl * fix redis vl * fix debug_sso_callback * fix linting error * fix redis semantic caching dep
This commit is contained in:
parent
a8da15c0f6
commit
3484100aed
5 changed files with 900 additions and 175 deletions
|
@ -3,6 +3,9 @@ Has all /sso/* routes
|
|||
|
||||
/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback
|
||||
/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI
|
||||
|
||||
/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback
|
||||
/sso/debug/callback - returns the OpenID object returned by the SSO provider
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
@ -36,6 +39,9 @@ from litellm.proxy.common_utils.admin_ui_utils import (
|
|||
admin_ui_disabled,
|
||||
show_missing_vars_in_env,
|
||||
)
|
||||
from litellm.proxy.common_utils.html_forms.jwt_display_template import (
|
||||
jwt_display_template,
|
||||
)
|
||||
from litellm.proxy.common_utils.html_forms.ui_login import html_form
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||
from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||
|
@ -92,131 +98,29 @@ async def google_login(request: Request): # noqa: PLR0915
|
|||
missing_env_vars = show_missing_vars_in_env()
|
||||
if missing_env_vars is not None:
|
||||
return missing_env_vars
|
||||
ui_username = os.getenv("UI_USERNAME")
|
||||
|
||||
# get url from request
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
ui_username = os.getenv("UI_USERNAME")
|
||||
if redirect_url.endswith("/"):
|
||||
redirect_url += "sso/callback"
|
||||
else:
|
||||
redirect_url += "/sso/callback"
|
||||
# Google SSO Auth
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
|
||||
request=request,
|
||||
sso_callback_route="sso/callback",
|
||||
)
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
client_secret=google_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
# Check if we should use SSO handler
|
||||
if (
|
||||
SSOAuthenticationHandler.should_use_sso_handler(
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
||||
is True
|
||||
):
|
||||
return await SSOAuthenticationHandler.get_sso_login_redirect(
|
||||
redirect_url=redirect_url,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
with google_sso:
|
||||
return await google_sso.get_login_redirect()
|
||||
# Microsoft SSO Auth
|
||||
elif microsoft_client_id is not None:
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
generic_authorization_endpoint = os.getenv(
|
||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
scope=generic_scope,
|
||||
)
|
||||
with generic_sso:
|
||||
# TODO: state should be a random string and added to the user session with cookie
|
||||
# or a cryptographicly signed state that we can verify stateless
|
||||
# For simplification we are using a static state, this is not perfect but some
|
||||
# SSO providers do not allow stateless verification
|
||||
redirect_params = {}
|
||||
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||
|
||||
if state:
|
||||
redirect_params["state"] = state
|
||||
elif "okta" in generic_authorization_endpoint:
|
||||
redirect_params[
|
||||
"state"
|
||||
] = uuid.uuid4().hex # set state param for okta - required
|
||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||
elif ui_username is not None:
|
||||
# No Google, Microsoft SSO
|
||||
# Use UI Credentials set in .env
|
||||
|
@ -271,7 +175,7 @@ async def get_generic_sso_response(
|
|||
jwt_handler: JWTHandler,
|
||||
generic_client_id: str,
|
||||
redirect_url: str,
|
||||
) -> Optional[OpenID]:
|
||||
) -> Union[OpenID, dict]:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
@ -348,7 +252,7 @@ async def get_generic_sso_response(
|
|||
request, params={"include_client_id": generic_include_client_id}
|
||||
)
|
||||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
return result
|
||||
return result or {}
|
||||
|
||||
|
||||
async def create_team_member_add_task(team_id, user_info):
|
||||
|
@ -443,54 +347,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
|
||||
result = None
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
redirect_uri=redirect_url,
|
||||
client_secret=google_client_secret,
|
||||
)
|
||||
result = await google_sso.verify_and_process(request)
|
||||
elif microsoft_client_id is not None:
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if microsoft_tenant is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_TENANT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_TENANT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
original_msft_result = await microsoft_sso.verify_and_process(
|
||||
result = await GoogleSSOHandler.get_google_callback_response(
|
||||
request=request,
|
||||
convert_response=False,
|
||||
google_client_id=google_client_id,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
result = MicrosoftSSOHandler.openid_from_response(
|
||||
response=original_msft_result,
|
||||
elif microsoft_client_id is not None:
|
||||
result = await MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||
request=request,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
redirect_url=redirect_url,
|
||||
jwt_handler=jwt_handler,
|
||||
)
|
||||
elif generic_client_id is not None:
|
||||
|
@ -705,7 +571,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
|
||||
|
||||
async def insert_sso_user(
|
||||
result_openid: Optional[OpenID],
|
||||
result_openid: Optional[Union[OpenID, dict]],
|
||||
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
||||
) -> NewUserResponse:
|
||||
"""
|
||||
|
@ -721,6 +587,10 @@ async def insert_sso_user(
|
|||
verbose_proxy_logger.debug(
|
||||
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
||||
)
|
||||
if result_openid is None:
|
||||
raise ValueError("result_openid is None")
|
||||
if isinstance(result_openid, dict):
|
||||
result_openid = OpenID(**result_openid)
|
||||
|
||||
if user_defined_values is None:
|
||||
raise ValueError("user_defined_values is None")
|
||||
|
@ -733,9 +603,9 @@ async def insert_sso_user(
|
|||
if user_defined_values.get("max_budget") is None:
|
||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||
if user_defined_values.get("budget_duration") is None:
|
||||
user_defined_values[
|
||||
"budget_duration"
|
||||
] = litellm.internal_user_budget_duration
|
||||
user_defined_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
|
||||
if user_defined_values["user_role"] is None:
|
||||
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
@ -789,11 +659,242 @@ async def get_ui_settings(request: Request):
|
|||
}
|
||||
|
||||
|
||||
class SSOAuthenticationHandler:
|
||||
"""
|
||||
Handler for SSO Authentication across all SSO providers
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_sso_login_redirect(
|
||||
redirect_url: str,
|
||||
google_client_id: Optional[str] = None,
|
||||
microsoft_client_id: Optional[str] = None,
|
||||
generic_client_id: Optional[str] = None,
|
||||
) -> Optional[RedirectResponse]:
|
||||
"""
|
||||
Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url`
|
||||
|
||||
Args:
|
||||
redirect_url (str): The URL to redirect the user to after login
|
||||
google_client_id (Optional[str], optional): The Google Client ID. Defaults to None.
|
||||
microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None.
|
||||
generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None.
|
||||
|
||||
Returns:
|
||||
RedirectResponse: The redirect response from the SSO provider
|
||||
"""
|
||||
# Google SSO Auth
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
client_secret=google_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
||||
)
|
||||
with google_sso:
|
||||
return await google_sso.get_login_redirect()
|
||||
# Microsoft SSO Auth
|
||||
elif microsoft_client_id is not None:
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(
|
||||
" "
|
||||
)
|
||||
generic_authorization_endpoint = os.getenv(
|
||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
scope=generic_scope,
|
||||
)
|
||||
with generic_sso:
|
||||
# TODO: state should be a random string and added to the user session with cookie
|
||||
# or a cryptographicly signed state that we can verify stateless
|
||||
# For simplification we are using a static state, this is not perfect but some
|
||||
# SSO providers do not allow stateless verification
|
||||
redirect_params = {}
|
||||
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||
|
||||
if state:
|
||||
redirect_params["state"] = state
|
||||
elif "okta" in generic_authorization_endpoint:
|
||||
redirect_params["state"] = (
|
||||
uuid.uuid4().hex
|
||||
) # set state param for okta - required
|
||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||
raise ValueError(
|
||||
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_use_sso_handler(
|
||||
google_client_id: Optional[str] = None,
|
||||
microsoft_client_id: Optional[str] = None,
|
||||
generic_client_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
if (
|
||||
google_client_id is not None
|
||||
or microsoft_client_id is not None
|
||||
or generic_client_id is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_redirect_url_for_sso(
|
||||
request: Request,
|
||||
sso_callback_route: str,
|
||||
) -> str:
|
||||
"""
|
||||
Get the redirect URL for SSO
|
||||
"""
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
if redirect_url.endswith("/"):
|
||||
redirect_url += sso_callback_route
|
||||
else:
|
||||
redirect_url += "/" + sso_callback_route
|
||||
return redirect_url
|
||||
|
||||
|
||||
class MicrosoftSSOHandler:
|
||||
"""
|
||||
Handles Microsoft SSO callback response and returns a CustomOpenID object
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_microsoft_callback_response(
|
||||
request: Request,
|
||||
microsoft_client_id: str,
|
||||
redirect_url: str,
|
||||
jwt_handler: JWTHandler,
|
||||
return_raw_sso_response: bool = False,
|
||||
) -> Union[CustomOpenID, OpenID, dict]:
|
||||
"""
|
||||
Get the Microsoft SSO callback response
|
||||
|
||||
Args:
|
||||
return_raw_sso_response: If True, return the raw SSO response
|
||||
"""
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if microsoft_tenant is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_TENANT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_TENANT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
original_msft_result = await microsoft_sso.verify_and_process(
|
||||
request=request,
|
||||
convert_response=False,
|
||||
)
|
||||
|
||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||
if return_raw_sso_response:
|
||||
return original_msft_result or {}
|
||||
|
||||
result = MicrosoftSSOHandler.openid_from_response(
|
||||
response=original_msft_result,
|
||||
jwt_handler=jwt_handler,
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def openid_from_response(
|
||||
response: Optional[dict], jwt_handler: JWTHandler
|
||||
|
@ -811,3 +912,181 @@ class MicrosoftSSOHandler:
|
|||
)
|
||||
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
|
||||
return openid_response
|
||||
|
||||
|
||||
class GoogleSSOHandler:
|
||||
"""
|
||||
Handles Google SSO callback response and returns a CustomOpenID object
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_google_callback_response(
|
||||
request: Request,
|
||||
google_client_id: str,
|
||||
redirect_url: str,
|
||||
return_raw_sso_response: bool = False,
|
||||
) -> Union[OpenID, dict]:
|
||||
"""
|
||||
Get the Google SSO callback response
|
||||
|
||||
Args:
|
||||
return_raw_sso_response: If True, return the raw SSO response
|
||||
"""
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
redirect_uri=redirect_url,
|
||||
client_secret=google_client_secret,
|
||||
)
|
||||
|
||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||
if return_raw_sso_response:
|
||||
return (
|
||||
await google_sso.verify_and_process(
|
||||
request=request,
|
||||
convert_response=False,
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
result = await google_sso.verify_and_process(request)
|
||||
return result or {}
|
||||
|
||||
|
||||
@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False)
|
||||
async def debug_sso_login(request: Request):
|
||||
"""
|
||||
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
||||
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
||||
Example:
|
||||
"""
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
####### Check if user is a Enterprise / Premium User #######
|
||||
if (
|
||||
microsoft_client_id is not None
|
||||
or google_client_id is not None
|
||||
or generic_client_id is not None
|
||||
):
|
||||
if premium_user is not True:
|
||||
raise ProxyException(
|
||||
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="premium_user",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
# get url from request
|
||||
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
|
||||
request=request,
|
||||
sso_callback_route="sso/debug/callback",
|
||||
)
|
||||
|
||||
# Check if we should use SSO handler
|
||||
if (
|
||||
SSOAuthenticationHandler.should_use_sso_handler(
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
is True
|
||||
):
|
||||
return await SSOAuthenticationHandler.get_sso_login_redirect(
|
||||
redirect_url=redirect_url,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False)
|
||||
async def debug_sso_callback(request: Request):
|
||||
"""
|
||||
Returns the OpenID object returned by the SSO provider
|
||||
"""
|
||||
import json
|
||||
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from litellm.proxy.proxy_server import jwt_handler
|
||||
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
if redirect_url.endswith("/"):
|
||||
redirect_url += "sso/debug/callback"
|
||||
else:
|
||||
redirect_url += "/sso/debug/callback"
|
||||
|
||||
result = None
|
||||
if google_client_id is not None:
|
||||
result = await GoogleSSOHandler.get_google_callback_response(
|
||||
request=request,
|
||||
google_client_id=google_client_id,
|
||||
redirect_url=redirect_url,
|
||||
return_raw_sso_response=True,
|
||||
)
|
||||
elif microsoft_client_id is not None:
|
||||
result = await MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||
request=request,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
redirect_url=redirect_url,
|
||||
jwt_handler=jwt_handler,
|
||||
return_raw_sso_response=True,
|
||||
)
|
||||
elif generic_client_id is not None:
|
||||
result = await get_generic_sso_response(
|
||||
request=request,
|
||||
jwt_handler=jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
|
||||
# If result is None, return a basic error message
|
||||
if result is None:
|
||||
return HTMLResponse(
|
||||
content="<h1>SSO Authentication Failed</h1><p>No data was returned from the SSO provider.</p>",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Convert the OpenID object to a dictionary
|
||||
if hasattr(result, "__dict__"):
|
||||
result_dict = result.__dict__
|
||||
else:
|
||||
result_dict = dict(result)
|
||||
|
||||
# Filter out any None values and convert to JSON serializable format
|
||||
filtered_result = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is not None and not key.startswith("_"):
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
filtered_result[key] = value
|
||||
else:
|
||||
try:
|
||||
# Try to convert to string or another JSON serializable format
|
||||
filtered_result[key] = str(value)
|
||||
except Exception as e:
|
||||
filtered_result[key] = f"Complex value (not displayable): {str(e)}"
|
||||
|
||||
# Replace the placeholder in the template with the actual data
|
||||
html_content = jwt_display_template.replace(
|
||||
"const userData = SSO_DATA;",
|
||||
f"const userData = {json.dumps(filtered_result, indent=2)};",
|
||||
)
|
||||
|
||||
return HTMLResponse(content=html_content)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue