From cf7dcd9168c5481e149c2e6515152037f22899e9 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 20 Sep 2024 19:14:33 -0700 Subject: [PATCH] [Feat-Proxy] Allow using custom sso handler (#5809) * update internal user doc string * add readme on location of /sso routes * add custom_sso_handler * docs custom sso * use secure=True for cookies --- docs/my-website/docs/proxy/custom_sso.md | 83 +++ docs/my-website/sidebars.js | 2 +- litellm/proxy/README.md | 3 +- litellm/proxy/custom_sso.py | 49 ++ .../internal_user_endpoints.py | 1 + litellm/proxy/management_endpoints/ui_sso.py | 618 ++++++++++++++++++ litellm/proxy/proxy_config.yaml | 1 + litellm/proxy/proxy_server.py | 591 +---------------- 8 files changed, 769 insertions(+), 579 deletions(-) create mode 100644 docs/my-website/docs/proxy/custom_sso.md create mode 100644 litellm/proxy/custom_sso.py create mode 100644 litellm/proxy/management_endpoints/ui_sso.py diff --git a/docs/my-website/docs/proxy/custom_sso.md b/docs/my-website/docs/proxy/custom_sso.md new file mode 100644 index 000000000..a89de0f32 --- /dev/null +++ b/docs/my-website/docs/proxy/custom_sso.md @@ -0,0 +1,83 @@ +# Event Hook for SSO Login (Custom Handler) + +Use this if you want to run your own code after a user signs on to the LiteLLM UI using SSO + +## How it works +- User lands on Admin UI +- LiteLLM redirects user to your SSO provider +- Your SSO provider redirects user back to LiteLLM +- LiteLLM has retrieved user information from your IDP +- **Your custom SSO handler is called and returns an object of type SSOUserDefinedValues** +- User signed in to UI + +## Usage + +#### 1. Create a custom sso handler file. + +Make sure the response type follows the `SSOUserDefinedValues` pydantic object. This is used for logging the user into the Admin UI + +```python +from fastapi import Request +from fastapi_sso.sso.base import OpenID + +from litellm.proxy._types import LitellmUserRoles, SSOUserDefinedValues +from litellm.proxy.management_endpoints.internal_user_endpoints import ( + new_user, + user_info, +) +from litellm.proxy.management_endpoints.team_endpoints import add_new_member + + +async def custom_sso_handler(userIDPInfo: OpenID) -> SSOUserDefinedValues: + try: + print("inside custom sso handler") # noqa + print(f"userIDPInfo: {userIDPInfo}") # noqa + + if userIDPInfo.id is None: + raise ValueError( + f"No ID found for user. userIDPInfo.id is None {userIDPInfo}" + ) + + + ################################################# + # Run you custom code / logic here + # check if user exists in litellm proxy DB + _user_info = await user_info(user_id=userIDPInfo.id) + print("_user_info from litellm DB ", _user_info) # noqa + ################################################# + + return SSOUserDefinedValues( + models=[], # models user has access to + user_id=userIDPInfo.id, # user id to use in the LiteLLM DB + user_email=userIDPInfo.email, # user email to use in the LiteLLM DB + user_role=LitellmUserRoles.INTERNAL_USER.value, # role to use for the user + max_budget=0.01, # Max budget for this UI login Session + budget_duration="1d", # Duration of the budget for this UI login Session, 1d, 2d, 30d ... + ) + except Exception as e: + raise Exception("Failed custom auth") +``` + +#### 2. Pass the filepath (relative to the config.yaml) + +Pass the filepath to the config.yaml + +e.g. if they're both in the same dir - `./config.yaml` and `./custom_sso.py`, this is what it looks like: +```yaml +model_list: + - model_name: "openai-model" + litellm_params: + model: "gpt-3.5-turbo" + +litellm_settings: + drop_params: True + set_verbose: True + +general_settings: + custom_sso: custom_sso.custom_sso_handler +``` + +#### 3. Start the proxy +```shell +$ litellm --config /path/to/config.yaml +``` diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 813fc75a6..b6ac55bb9 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -65,7 +65,7 @@ const sidebars = { { type: "category", label: "Admin UI", - items: ["proxy/ui", "proxy/self_serve"], + items: ["proxy/ui", "proxy/self_serve", "proxy/custom_sso"], }, { type: "category", diff --git a/litellm/proxy/README.md b/litellm/proxy/README.md index 552c9777b..6c0d3f984 100644 --- a/litellm/proxy/README.md +++ b/litellm/proxy/README.md @@ -40,4 +40,5 @@ print(response) - `health_endpoints/` - `/health`, `/health/liveliness`, `/health/readiness` - `management_endpoints/key_management_endpoints.py` - all `/key/*` routes - `management_endpoints/team_endpoints.py` - all `/team/*` routes -- `management_endpoints/internal_user_endpoints.py` - all `/user/*` routes \ No newline at end of file +- `management_endpoints/internal_user_endpoints.py` - all `/user/*` routes +- `management_endpoints/ui_sso.py` - all `/sso/*` routes \ No newline at end of file diff --git a/litellm/proxy/custom_sso.py b/litellm/proxy/custom_sso.py new file mode 100644 index 000000000..10ff8910f --- /dev/null +++ b/litellm/proxy/custom_sso.py @@ -0,0 +1,49 @@ +""" +Example Custom SSO Handler + +Use this if you want to run custom code after litellm has retrieved information from your IDP (Identity Provider). + +Flow: +- User lands on Admin UI +- LiteLLM redirects user to your SSO provider +- Your SSO provider redirects user back to LiteLLM +- LiteLLM has retrieved user information from your IDP +- Your custom SSO handler is called and returns an object of type SSOUserDefinedValues +- User signed in to UI +""" + +from fastapi import Request +from fastapi_sso.sso.base import OpenID + +from litellm.proxy._types import LitellmUserRoles, SSOUserDefinedValues +from litellm.proxy.management_endpoints.internal_user_endpoints import ( + new_user, + user_info, +) +from litellm.proxy.management_endpoints.team_endpoints import add_new_member + + +async def custom_sso_handler(userIDPInfo: OpenID) -> SSOUserDefinedValues: + try: + print("inside custom sso handler") # noqa + print(f"userIDPInfo: {userIDPInfo}") # noqa + + if userIDPInfo.id is None: + raise ValueError( + f"No ID found for user. userIDPInfo.id is None {userIDPInfo}" + ) + + # check if user exists in litellm proxy DB + _user_info = await user_info(user_id=userIDPInfo.id) + print("_user_info from litellm DB ", _user_info) # noqa + + return SSOUserDefinedValues( + models=[], + user_id=userIDPInfo.id, + user_email=userIDPInfo.email, + user_role=LitellmUserRoles.INTERNAL_USER.value, + max_budget=10, + budget_duration="1d", + ) + except Exception as e: + raise Exception("Failed custom auth") diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 3801d7465..771234557 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -7,6 +7,7 @@ These are members of a Team on LiteLLM /user/new /user/update /user/delete +/user/info """ import asyncio diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py new file mode 100644 index 000000000..d964c77a8 --- /dev/null +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -0,0 +1,618 @@ +""" +Has all /sso/* routes + +/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback +/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI +""" + +import asyncio +import os +import uuid +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + LitellmUserRoles, + ProxyErrorTypes, + ProxyException, + SSOUserDefinedValues, +) +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_utils.admin_ui_utils import ( + admin_ui_disabled, + html_form, + show_missing_vars_in_env, +) +from litellm.secret_managers.main import str_to_bool + +router = APIRouter() + + +@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) +async def google_login(request: Request): + """ + Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env + PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" + Example: + """ + from litellm.proxy.proxy_server import master_key, premium_user, prisma_client + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + + ####### Check if UI is disabled ####### + _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI") + if _disable_ui_flag is not None: + is_disabled = str_to_bool(value=_disable_ui_flag) + if is_disabled: + return admin_ui_disabled() + + ####### Check if user is a Enterprise / Premium User ####### + if ( + microsoft_client_id is not None + or google_client_id is not None + or generic_client_id is not None + ): + if premium_user != True: + raise ProxyException( + message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this", + type=ProxyErrorTypes.auth_error, + param="premium_user", + code=status.HTTP_403_FORBIDDEN, + ) + + ####### Detect DB + MASTER KEY in .env ####### + missing_env_vars = show_missing_vars_in_env() + if missing_env_vars is not None: + return missing_env_vars + + # get url from request + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + ui_username = os.getenv("UI_USERNAME") + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + # Google SSO Auth + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + client_secret=google_client_secret, + redirect_uri=redirect_url, + ) + verbose_proxy_logger.info( + f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + ) + with google_sso: + return await google_sso.get_login_redirect() + # Microsoft SSO Auth + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + with microsoft_sso: + return await microsoft_sso.get_login_redirect() + elif generic_client_id is not None: + from fastapi_sso.sso.base import DiscoveryDocument + from fastapi_sso.sso.generic import create_provider + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + SSOProvider = create_provider(name="oidc", discovery_document=discovery) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + scope=generic_scope, + ) + with generic_sso: + # TODO: state should be a random string and added to the user session with cookie + # or a cryptographicly signed state that we can verify stateless + # For simplification we are using a static state, this is not perfect but some + # SSO providers do not allow stateless verification + redirect_params = {} + state = os.getenv("GENERIC_CLIENT_STATE", None) + + if state: + redirect_params["state"] = state + elif "okta" in generic_authorization_endpoint: + redirect_params["state"] = ( + uuid.uuid4().hex + ) # set state param for okta - required + return await generic_sso.get_login_redirect(**redirect_params) # type: ignore + elif ui_username is not None: + # No Google, Microsoft SSO + # Use UI Credentials set in .env + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + else: + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + + +@router.get("/sso/callback", tags=["experimental"], include_in_schema=False) +async def auth_callback(request: Request): + """Verify login""" + from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_helper_fn, + ) + from litellm.proxy.proxy_server import ( + general_settings, + master_key, + premium_user, + prisma_client, + ui_access_mode, + user_custom_sso, + ) + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + # get url from request + if master_key is None: + raise ProxyException( + message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.", + type=ProxyErrorTypes.auth_error, + param="master_key", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + + result = None + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + redirect_uri=redirect_url, + client_secret=google_client_secret, + ) + result = await google_sso.verify_and_process(request) + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if microsoft_tenant is None: + raise ProxyException( + message="MICROSOFT_TENANT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_TENANT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + result = await microsoft_sso.verify_and_process(request) + elif generic_client_id is not None: + # make generic sso provider + from fastapi_sso.sso.base import DiscoveryDocument, OpenID + from fastapi_sso.sso.generic import create_provider + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + generic_include_client_id = ( + os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true" + ) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + + generic_user_id_attribute_name = os.getenv( + "GENERIC_USER_ID_ATTRIBUTE", "preferred_username" + ) + generic_user_display_name_attribute_name = os.getenv( + "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub" + ) + generic_user_email_attribute_name = os.getenv( + "GENERIC_USER_EMAIL_ATTRIBUTE", "email" + ) + generic_user_role_attribute_name = os.getenv( + "GENERIC_USER_ROLE_ATTRIBUTE", "role" + ) + generic_user_first_name_attribute_name = os.getenv( + "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name" + ) + generic_user_last_name_attribute_name = os.getenv( + "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name" + ) + + verbose_proxy_logger.debug( + f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}" + ) + + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + + def response_convertor(response, client): + return OpenID( + id=response.get(generic_user_id_attribute_name), + display_name=response.get(generic_user_display_name_attribute_name), + email=response.get(generic_user_email_attribute_name), + first_name=response.get(generic_user_first_name_attribute_name), + last_name=response.get(generic_user_last_name_attribute_name), + ) + + SSOProvider = create_provider( + name="oidc", + discovery_document=discovery, + response_convertor=response_convertor, + ) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + scope=generic_scope, + ) + verbose_proxy_logger.debug("calling generic_sso.verify_and_process") + result = await generic_sso.verify_and_process( + request, params={"include_client_id": generic_include_client_id} + ) + verbose_proxy_logger.debug("generic result: %s", result) + + # User is Authe'd in - generate key for the UI to access Proxy + user_email: Optional[str] = getattr(result, "email", None) + user_id: Optional[str] = getattr(result, "id", None) if result is not None else None + + if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None: + email_domain = user_email.split("@")[1] + allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",") # type: ignore + if email_domain not in allowed_domains: + raise HTTPException( + status_code=401, + detail={ + "message": "The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format( + email_domain, allowed_domains + ) + }, + ) + + # generic client id + if generic_client_id is not None and result is not None: + user_id = getattr(result, "id", None) + user_email = getattr(result, "email", None) + user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore + + if user_id is None and result is not None: + _first_name = getattr(result, "first_name", "") or "" + _last_name = getattr(result, "last_name", "") or "" + user_id = _first_name + _last_name + + if user_email is not None and (user_id is None or len(user_id) == 0): + user_id = user_email + + user_info = None + user_id_models: List = [] + max_internal_user_budget = litellm.max_internal_user_budget + internal_user_budget_duration = litellm.internal_user_budget_duration + + # User might not be already created on first generation of key + # But if it is, we want their models preferences + default_ui_key_values = { + "duration": "24hr", + "key_max_budget": 0.01, + "aliases": {}, + "config": {}, + "spend": 0, + "team_id": "litellm-dashboard", + } + user_defined_values: Optional[SSOUserDefinedValues] = None + + if user_custom_sso is not None: + if asyncio.iscoroutinefunction(user_custom_sso): + user_defined_values = await user_custom_sso(result) # type: ignore + else: + raise ValueError("user_custom_sso must be a coroutine function") + elif user_id is not None: + user_defined_values = SSOUserDefinedValues( + models=user_id_models, + user_id=user_id, + user_email=user_email, + max_budget=max_internal_user_budget, + user_role=None, + budget_duration=internal_user_budget_duration, + ) + + _user_id_from_sso = user_id + user_role = None + try: + if prisma_client is not None: + user_info = await prisma_client.get_data(user_id=user_id, table_name="user") + verbose_proxy_logger.debug( + f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}" + ) + if user_info is None: + ## check if user-email in db ## + user_info = await prisma_client.db.litellm_usertable.find_first( + where={"user_email": user_email} + ) + + if user_info is not None and user_id is not None: + user_defined_values = SSOUserDefinedValues( + models=getattr(user_info, "models", user_id_models), + user_id=user_id, + user_email=getattr(user_info, "user_email", user_email), + user_role=getattr(user_info, "user_role", None), + max_budget=getattr( + user_info, "max_budget", max_internal_user_budget + ), + budget_duration=getattr( + user_info, "budget_duration", internal_user_budget_duration + ), + ) + + user_role = getattr(user_info, "user_role", None) + + # update id + await prisma_client.db.litellm_usertable.update_many( + where={"user_email": user_email}, data={"user_id": user_id} # type: ignore + ) + elif litellm.default_user_params is not None and isinstance( + litellm.default_user_params, dict + ): + user_defined_values = { + "models": litellm.default_user_params.get("models", user_id_models), + "user_id": litellm.default_user_params.get("user_id", user_id), + "user_email": litellm.default_user_params.get( + "user_email", user_email + ), + "user_role": litellm.default_user_params.get("user_role", None), + "max_budget": litellm.default_user_params.get( + "max_budget", max_internal_user_budget + ), + "budget_duration": litellm.default_user_params.get( + "budget_duration", internal_user_budget_duration + ), + } + + except Exception as e: + pass + + if user_defined_values is None: + raise Exception( + "Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues" + ) + + is_internal_user = False + if ( + user_defined_values["user_role"] is not None + and user_defined_values["user_role"] == LitellmUserRoles.INTERNAL_USER.value + ): + is_internal_user = True + if ( + is_internal_user is True + and user_defined_values["max_budget"] is None + and litellm.max_internal_user_budget is not None + ): + user_defined_values["max_budget"] = litellm.max_internal_user_budget + + if ( + is_internal_user is True + and user_defined_values["budget_duration"] is None + and litellm.internal_user_budget_duration is not None + ): + user_defined_values["budget_duration"] = litellm.internal_user_budget_duration + + verbose_proxy_logger.info( + f"user_defined_values for creating ui key: {user_defined_values}" + ) + + default_ui_key_values.update(user_defined_values) + default_ui_key_values["request_type"] = "key" + response = await generate_key_helper_fn( + **default_ui_key_values, # type: ignore + ) + key = response["token"] # type: ignore + user_id = response["user_id"] # type: ignore + + # This should always be true + # User_id on SSO == user_id in the LiteLLM_VerificationToken Table + assert user_id == _user_id_from_sso + litellm_dashboard_ui = "/ui/" + user_role = user_role or "app_owner" + if ( + os.getenv("PROXY_ADMIN_ID", None) is not None + and os.environ["PROXY_ADMIN_ID"] == user_id + ): + # checks if user is admin + user_role = "app_admin" + + verbose_proxy_logger.debug( + f"user_role: {user_role}; ui_access_mode: {ui_access_mode}" + ) + ## CHECK IF ROLE ALLOWED TO USE PROXY ## + if ui_access_mode == "admin_only" and "admin" not in user_role: + verbose_proxy_logger.debug("EXCEPTION RAISED") + raise HTTPException( + status_code=401, + detail={ + "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}" + }, + ) + + import jwt + + jwt_token = jwt.encode( # type: ignore + { + "user_id": user_id, + "key": key, + "user_email": user_email, + "user_role": user_role, + "login_method": "sso", + "premium_user": premium_user, + "auth_header_name": general_settings.get( + "litellm_key_header_name", "Authorization" + ), + }, + master_key, + algorithm="HS256", + ) + if user_id is not None and isinstance(user_id, str): + litellm_dashboard_ui += "?userID=" + user_id + redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) + redirect_response.set_cookie(key="token", value=jwt_token, secure=True) + return redirect_response + + +@router.get( + "/sso/get/ui_settings", + tags=["experimental"], + include_in_schema=False, + dependencies=[Depends(user_api_key_auth)], +) +async def get_ui_settings(request: Request): + from litellm.proxy.proxy_server import general_settings + + _proxy_base_url = os.getenv("PROXY_BASE_URL", None) + _logout_url = os.getenv("PROXY_LOGOUT_URL", None) + + default_team_disabled = general_settings.get("default_team_disabled", False) + if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ: + if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true": + default_team_disabled = True + + return { + "PROXY_BASE_URL": _proxy_base_url, + "PROXY_LOGOUT_URL": _logout_url, + "DEFAULT_TEAM_DISABLED": default_team_disabled, + } diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 1795adf71..e20aa8c28 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -25,6 +25,7 @@ model_list: general_settings: master_key: sk-1234 default_team_disabled: true + custom_sso: custom_sso.custom_sso_handler litellm_settings: success_callback: ["prometheus"] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9f3eacf18..61e6879fd 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -138,11 +138,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth ## Import All Misc routes here ## from litellm.proxy.caching_routes import router as caching_router -from litellm.proxy.common_utils.admin_ui_utils import ( - admin_ui_disabled, - html_form, - show_missing_vars_in_env, -) +from litellm.proxy.common_utils.admin_ui_utils import html_form from litellm.proxy.common_utils.callback_utils import ( get_logging_caching_headers, get_remaining_tokens_and_requests_from_request_data, @@ -194,6 +190,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import ( router as team_callback_router, ) from litellm.proxy.management_endpoints.team_endpoints import router as team_router +from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model from litellm.proxy.openai_files_endpoints.files_endpoints import ( router as openai_files_router, @@ -488,6 +485,7 @@ redis_usage_cache: Optional[RedisCache] = ( ) user_custom_auth = None user_custom_key_generate = None +user_custom_sso = None use_background_health_checks = None use_queue = False health_check_interval = None @@ -1484,7 +1482,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings # Load existing config if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None: @@ -1846,6 +1844,13 @@ class ProxyConfig: user_custom_key_generate = get_instance_fn( value=custom_key_generate, config_file_path=config_file_path ) + + custom_sso = general_settings.get("custom_sso", None) + if custom_sso is not None: + user_custom_sso = get_instance_fn( + value=custom_sso, config_file_path=config_file_path + ) + ## pass through endpoints if general_settings.get("pass_through_endpoints", None) is not None: await initialize_pass_through_endpoints( @@ -7964,184 +7969,6 @@ async def async_queue_request( ) -#### LOGIN ENDPOINTS #### - - -@app.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) -async def google_login(request: Request): - """ - Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env - PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" - Example: - """ - global premium_user, prisma_client, master_key - - microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) - google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) - generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) - - ####### Check if UI is disabled ####### - _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI") - if _disable_ui_flag is not None: - is_disabled = str_to_bool(value=_disable_ui_flag) - if is_disabled: - return admin_ui_disabled() - - ####### Check if user is a Enterprise / Premium User ####### - if ( - microsoft_client_id is not None - or google_client_id is not None - or generic_client_id is not None - ): - if premium_user != True: - raise ProxyException( - message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this", - type=ProxyErrorTypes.auth_error, - param="premium_user", - code=status.HTTP_403_FORBIDDEN, - ) - - ####### Detect DB + MASTER KEY in .env ####### - missing_env_vars = show_missing_vars_in_env() - if missing_env_vars is not None: - return missing_env_vars - - # get url from request - redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - ui_username = os.getenv("UI_USERNAME") - if redirect_url.endswith("/"): - redirect_url += "sso/callback" - else: - redirect_url += "/sso/callback" - # Google SSO Auth - if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO - - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - google_sso = GoogleSSO( - client_id=google_client_id, - client_secret=google_client_secret, - redirect_uri=redirect_url, - ) - verbose_proxy_logger.info( - f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" - ) - with google_sso: - return await google_sso.get_login_redirect() - # Microsoft SSO Auth - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - with microsoft_sso: - return await microsoft_sso.get_login_redirect() - elif generic_client_id is not None: - from fastapi_sso.sso.base import DiscoveryDocument - from fastapi_sso.sso.generic import create_provider - - generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) - generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") - generic_authorization_endpoint = os.getenv( - "GENERIC_AUTHORIZATION_ENDPOINT", None - ) - generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) - generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) - if generic_client_secret is None: - raise ProxyException( - message="GENERIC_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_authorization_endpoint is None: - raise ProxyException( - message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_AUTHORIZATION_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_token_endpoint is None: - raise ProxyException( - message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_TOKEN_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_userinfo_endpoint is None: - raise ProxyException( - message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_USERINFO_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - verbose_proxy_logger.debug( - f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" - ) - verbose_proxy_logger.debug( - f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" - ) - discovery = DiscoveryDocument( - authorization_endpoint=generic_authorization_endpoint, - token_endpoint=generic_token_endpoint, - userinfo_endpoint=generic_userinfo_endpoint, - ) - SSOProvider = create_provider(name="oidc", discovery_document=discovery) - generic_sso = SSOProvider( - client_id=generic_client_id, - client_secret=generic_client_secret, - redirect_uri=redirect_url, - allow_insecure_http=True, - scope=generic_scope, - ) - with generic_sso: - # TODO: state should be a random string and added to the user session with cookie - # or a cryptographicly signed state that we can verify stateless - # For simplification we are using a static state, this is not perfect but some - # SSO providers do not allow stateless verification - redirect_params = {} - state = os.getenv("GENERIC_CLIENT_STATE", None) - - if state: - redirect_params["state"] = state - elif "okta" in generic_authorization_endpoint: - redirect_params["state"] = ( - uuid.uuid4().hex - ) # set state param for okta - required - return await generic_sso.get_login_redirect(**redirect_params) # type: ignore - elif ui_username is not None: - # No Google, Microsoft SSO - # Use UI Credentials set in .env - from fastapi.responses import HTMLResponse - - return HTMLResponse(content=html_form, status_code=200) - else: - from fastapi.responses import HTMLResponse - - return HTMLResponse(content=html_form, status_code=200) - - @app.get("/fallback/login", tags=["experimental"], include_in_schema=False) async def fallback_login(request: Request): """ @@ -8371,28 +8198,6 @@ async def login(request: Request): ) -@app.get( - "/sso/get/ui_settings", - tags=["experimental"], - include_in_schema=False, - dependencies=[Depends(user_api_key_auth)], -) -async def get_ui_settings(request: Request): - _proxy_base_url = os.getenv("PROXY_BASE_URL", None) - _logout_url = os.getenv("PROXY_LOGOUT_URL", None) - - default_team_disabled = general_settings.get("default_team_disabled", False) - if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ: - if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true": - default_team_disabled = True - - return { - "PROXY_BASE_URL": _proxy_base_url, - "PROXY_LOGOUT_URL": _logout_url, - "DEFAULT_TEAM_DISABLED": default_team_disabled, - } - - @app.get("/onboarding/get_token", include_in_schema=False) async def onboarding(invite_link: str): """ @@ -8613,376 +8418,6 @@ def get_image(): return FileResponse(logo_path, media_type="image/jpeg") -@app.get("/sso/callback", tags=["experimental"], include_in_schema=False) -async def auth_callback(request: Request): - """Verify login""" - global general_settings, ui_access_mode, premium_user, master_key - microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) - google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) - generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) - # get url from request - if master_key is None: - raise ProxyException( - message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.", - type=ProxyErrorTypes.auth_error, - param="master_key", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - if redirect_url.endswith("/"): - redirect_url += "sso/callback" - else: - redirect_url += "/sso/callback" - - result = None - if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO - - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - google_sso = GoogleSSO( - client_id=google_client_id, - redirect_uri=redirect_url, - client_secret=google_client_secret, - ) - result = await google_sso.verify_and_process(request) - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if microsoft_tenant is None: - raise ProxyException( - message="MICROSOFT_TENANT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="MICROSOFT_TENANT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - result = await microsoft_sso.verify_and_process(request) - elif generic_client_id is not None: - # make generic sso provider - from fastapi_sso.sso.base import DiscoveryDocument, OpenID - from fastapi_sso.sso.generic import create_provider - - generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) - generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") - generic_authorization_endpoint = os.getenv( - "GENERIC_AUTHORIZATION_ENDPOINT", None - ) - generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) - generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) - generic_include_client_id = ( - os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true" - ) - if generic_client_secret is None: - raise ProxyException( - message="GENERIC_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_authorization_endpoint is None: - raise ProxyException( - message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_AUTHORIZATION_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_token_endpoint is None: - raise ProxyException( - message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_TOKEN_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_userinfo_endpoint is None: - raise ProxyException( - message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_USERINFO_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - verbose_proxy_logger.debug( - f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" - ) - verbose_proxy_logger.debug( - f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" - ) - - generic_user_id_attribute_name = os.getenv( - "GENERIC_USER_ID_ATTRIBUTE", "preferred_username" - ) - generic_user_display_name_attribute_name = os.getenv( - "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub" - ) - generic_user_email_attribute_name = os.getenv( - "GENERIC_USER_EMAIL_ATTRIBUTE", "email" - ) - generic_user_role_attribute_name = os.getenv( - "GENERIC_USER_ROLE_ATTRIBUTE", "role" - ) - generic_user_first_name_attribute_name = os.getenv( - "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name" - ) - generic_user_last_name_attribute_name = os.getenv( - "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name" - ) - - verbose_proxy_logger.debug( - f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}" - ) - - discovery = DiscoveryDocument( - authorization_endpoint=generic_authorization_endpoint, - token_endpoint=generic_token_endpoint, - userinfo_endpoint=generic_userinfo_endpoint, - ) - - def response_convertor(response, client): - return OpenID( - id=response.get(generic_user_id_attribute_name), - display_name=response.get(generic_user_display_name_attribute_name), - email=response.get(generic_user_email_attribute_name), - first_name=response.get(generic_user_first_name_attribute_name), - last_name=response.get(generic_user_last_name_attribute_name), - ) - - SSOProvider = create_provider( - name="oidc", - discovery_document=discovery, - response_convertor=response_convertor, - ) - generic_sso = SSOProvider( - client_id=generic_client_id, - client_secret=generic_client_secret, - redirect_uri=redirect_url, - allow_insecure_http=True, - scope=generic_scope, - ) - verbose_proxy_logger.debug("calling generic_sso.verify_and_process") - result = await generic_sso.verify_and_process( - request, params={"include_client_id": generic_include_client_id} - ) - verbose_proxy_logger.debug("generic result: %s", result) - - # User is Authe'd in - generate key for the UI to access Proxy - user_email: Optional[str] = getattr(result, "email", None) - user_id: Optional[str] = getattr(result, "id", None) if result is not None else None - - if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None: - email_domain = user_email.split("@")[1] - allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",") # type: ignore - if email_domain not in allowed_domains: - raise HTTPException( - status_code=401, - detail={ - "message": "The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format( - email_domain, allowed_domains - ) - }, - ) - - # generic client id - if generic_client_id is not None and result is not None: - user_id = getattr(result, "id", None) - user_email = getattr(result, "email", None) - user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore - - if user_id is None and result is not None: - _first_name = getattr(result, "first_name", "") or "" - _last_name = getattr(result, "last_name", "") or "" - user_id = _first_name + _last_name - - if user_email is not None and (user_id is None or len(user_id) == 0): - user_id = user_email - - user_info = None - user_id_models: List = [] - max_internal_user_budget = litellm.max_internal_user_budget - internal_user_budget_duration = litellm.internal_user_budget_duration - - # User might not be already created on first generation of key - # But if it is, we want their models preferences - default_ui_key_values = { - "duration": "24hr", - "key_max_budget": 0.01, - "aliases": {}, - "config": {}, - "spend": 0, - "team_id": "litellm-dashboard", - } - user_defined_values: Optional[SSOUserDefinedValues] = None - if user_id is not None: - user_defined_values = SSOUserDefinedValues( - models=user_id_models, - user_id=user_id, - user_email=user_email, - max_budget=max_internal_user_budget, - user_role=None, - budget_duration=internal_user_budget_duration, - ) - - _user_id_from_sso = user_id - user_role = None - try: - if prisma_client is not None: - user_info = await prisma_client.get_data(user_id=user_id, table_name="user") - verbose_proxy_logger.debug( - f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}" - ) - if user_info is None: - ## check if user-email in db ## - user_info = await prisma_client.db.litellm_usertable.find_first( - where={"user_email": user_email} - ) - - if user_info is not None and user_id is not None: - user_defined_values = SSOUserDefinedValues( - models=getattr(user_info, "models", user_id_models), - user_id=user_id, - user_email=getattr(user_info, "user_email", user_email), - user_role=getattr(user_info, "user_role", None), - max_budget=getattr( - user_info, "max_budget", max_internal_user_budget - ), - budget_duration=getattr( - user_info, "budget_duration", internal_user_budget_duration - ), - ) - - user_role = getattr(user_info, "user_role", None) - - # update id - await prisma_client.db.litellm_usertable.update_many( - where={"user_email": user_email}, data={"user_id": user_id} # type: ignore - ) - elif litellm.default_user_params is not None and isinstance( - litellm.default_user_params, dict - ): - user_defined_values = { - "models": litellm.default_user_params.get("models", user_id_models), - "user_id": litellm.default_user_params.get("user_id", user_id), - "user_email": litellm.default_user_params.get( - "user_email", user_email - ), - "user_role": litellm.default_user_params.get("user_role", None), - "max_budget": litellm.default_user_params.get( - "max_budget", max_internal_user_budget - ), - "budget_duration": litellm.default_user_params.get( - "budget_duration", internal_user_budget_duration - ), - } - - except Exception as e: - pass - - if user_defined_values is None: - raise Exception( - "Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues" - ) - - is_internal_user = False - if ( - user_defined_values["user_role"] is not None - and user_defined_values["user_role"] == LitellmUserRoles.INTERNAL_USER.value - ): - is_internal_user = True - if ( - is_internal_user is True - and user_defined_values["max_budget"] is None - and litellm.max_internal_user_budget is not None - ): - user_defined_values["max_budget"] = litellm.max_internal_user_budget - - if ( - is_internal_user is True - and user_defined_values["budget_duration"] is None - and litellm.internal_user_budget_duration is not None - ): - user_defined_values["budget_duration"] = litellm.internal_user_budget_duration - - verbose_proxy_logger.info( - f"user_defined_values for creating ui key: {user_defined_values}" - ) - - default_ui_key_values.update(user_defined_values) - default_ui_key_values["request_type"] = "key" - response = await generate_key_helper_fn( - **default_ui_key_values, # type: ignore - ) - key = response["token"] # type: ignore - user_id = response["user_id"] # type: ignore - - # This should always be true - # User_id on SSO == user_id in the LiteLLM_VerificationToken Table - assert user_id == _user_id_from_sso - litellm_dashboard_ui = "/ui/" - user_role = user_role or "app_owner" - if ( - os.getenv("PROXY_ADMIN_ID", None) is not None - and os.environ["PROXY_ADMIN_ID"] == user_id - ): - # checks if user is admin - user_role = "app_admin" - - verbose_proxy_logger.debug( - f"user_role: {user_role}; ui_access_mode: {ui_access_mode}" - ) - ## CHECK IF ROLE ALLOWED TO USE PROXY ## - if ui_access_mode == "admin_only" and "admin" not in user_role: - verbose_proxy_logger.debug("EXCEPTION RAISED") - raise HTTPException( - status_code=401, - detail={ - "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}" - }, - ) - - import jwt - - jwt_token = jwt.encode( # type: ignore - { - "user_id": user_id, - "key": key, - "user_email": user_email, - "user_role": user_role, - "login_method": "sso", - "premium_user": premium_user, - "auth_header_name": general_settings.get( - "litellm_key_header_name", "Authorization" - ), - }, - master_key, - algorithm="HS256", - ) - if user_id is not None and isinstance(user_id, str): - litellm_dashboard_ui += "?userID=" + user_id - redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) - redirect_response.set_cookie(key="token", value=jwt_token) - return redirect_response - - #### INVITATION MANAGEMENT #### @@ -10044,7 +9479,7 @@ async def shutdown_event(): def cleanup_router_config_variables(): - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, prisma_client, custom_db_client # Set all variables to None master_key = None @@ -10053,6 +9488,7 @@ def cleanup_router_config_variables(): user_custom_auth = None user_custom_auth_path = None user_custom_key_generate = None + user_custom_sso = None use_background_health_checks = None health_check_interval = None health_check_details = None @@ -10071,6 +9507,7 @@ app.include_router(health_router) app.include_router(key_management_router) app.include_router(internal_user_router) app.include_router(team_router) +app.include_router(ui_sso_router) app.include_router(spend_management_router) app.include_router(caching_router) app.include_router(analytics_router)