forked from phoenix/litellm-mirror
[Feat-Proxy] Allow using custom sso handler (#5809)
* update internal user doc string * add readme on location of /sso routes * add custom_sso_handler * docs custom sso * use secure=True for cookies
This commit is contained in:
parent
0a18b6539c
commit
cf7dcd9168
8 changed files with 769 additions and 579 deletions
83
docs/my-website/docs/proxy/custom_sso.md
Normal file
83
docs/my-website/docs/proxy/custom_sso.md
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Event Hook for SSO Login (Custom Handler)
|
||||||
|
|
||||||
|
Use this if you want to run your own code after a user signs on to the LiteLLM UI using SSO
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
- User lands on Admin UI
|
||||||
|
- LiteLLM redirects user to your SSO provider
|
||||||
|
- Your SSO provider redirects user back to LiteLLM
|
||||||
|
- LiteLLM has retrieved user information from your IDP
|
||||||
|
- **Your custom SSO handler is called and returns an object of type SSOUserDefinedValues**
|
||||||
|
- User signed in to UI
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
#### 1. Create a custom sso handler file.
|
||||||
|
|
||||||
|
Make sure the response type follows the `SSOUserDefinedValues` pydantic object. This is used for logging the user into the Admin UI
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi_sso.sso.base import OpenID
|
||||||
|
|
||||||
|
from litellm.proxy._types import LitellmUserRoles, SSOUserDefinedValues
|
||||||
|
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||||
|
new_user,
|
||||||
|
user_info,
|
||||||
|
)
|
||||||
|
from litellm.proxy.management_endpoints.team_endpoints import add_new_member
|
||||||
|
|
||||||
|
|
||||||
|
async def custom_sso_handler(userIDPInfo: OpenID) -> SSOUserDefinedValues:
|
||||||
|
try:
|
||||||
|
print("inside custom sso handler") # noqa
|
||||||
|
print(f"userIDPInfo: {userIDPInfo}") # noqa
|
||||||
|
|
||||||
|
if userIDPInfo.id is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"No ID found for user. userIDPInfo.id is None {userIDPInfo}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# Run you custom code / logic here
|
||||||
|
# check if user exists in litellm proxy DB
|
||||||
|
_user_info = await user_info(user_id=userIDPInfo.id)
|
||||||
|
print("_user_info from litellm DB ", _user_info) # noqa
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
return SSOUserDefinedValues(
|
||||||
|
models=[], # models user has access to
|
||||||
|
user_id=userIDPInfo.id, # user id to use in the LiteLLM DB
|
||||||
|
user_email=userIDPInfo.email, # user email to use in the LiteLLM DB
|
||||||
|
user_role=LitellmUserRoles.INTERNAL_USER.value, # role to use for the user
|
||||||
|
max_budget=0.01, # Max budget for this UI login Session
|
||||||
|
budget_duration="1d", # Duration of the budget for this UI login Session, 1d, 2d, 30d ...
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Failed custom auth")
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Pass the filepath (relative to the config.yaml)
|
||||||
|
|
||||||
|
Pass the filepath to the config.yaml
|
||||||
|
|
||||||
|
e.g. if they're both in the same dir - `./config.yaml` and `./custom_sso.py`, this is what it looks like:
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: "openai-model"
|
||||||
|
litellm_params:
|
||||||
|
model: "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
drop_params: True
|
||||||
|
set_verbose: True
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
custom_sso: custom_sso.custom_sso_handler
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Start the proxy
|
||||||
|
```shell
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
|
@ -65,7 +65,7 @@ const sidebars = {
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
label: "Admin UI",
|
label: "Admin UI",
|
||||||
items: ["proxy/ui", "proxy/self_serve"],
|
items: ["proxy/ui", "proxy/self_serve", "proxy/custom_sso"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
|
|
|
@ -41,3 +41,4 @@ print(response)
|
||||||
- `management_endpoints/key_management_endpoints.py` - all `/key/*` routes
|
- `management_endpoints/key_management_endpoints.py` - all `/key/*` routes
|
||||||
- `management_endpoints/team_endpoints.py` - all `/team/*` routes
|
- `management_endpoints/team_endpoints.py` - all `/team/*` routes
|
||||||
- `management_endpoints/internal_user_endpoints.py` - all `/user/*` routes
|
- `management_endpoints/internal_user_endpoints.py` - all `/user/*` routes
|
||||||
|
- `management_endpoints/ui_sso.py` - all `/sso/*` routes
|
49
litellm/proxy/custom_sso.py
Normal file
49
litellm/proxy/custom_sso.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
"""
|
||||||
|
Example Custom SSO Handler
|
||||||
|
|
||||||
|
Use this if you want to run custom code after litellm has retrieved information from your IDP (Identity Provider).
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
- User lands on Admin UI
|
||||||
|
- LiteLLM redirects user to your SSO provider
|
||||||
|
- Your SSO provider redirects user back to LiteLLM
|
||||||
|
- LiteLLM has retrieved user information from your IDP
|
||||||
|
- Your custom SSO handler is called and returns an object of type SSOUserDefinedValues
|
||||||
|
- User signed in to UI
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi_sso.sso.base import OpenID
|
||||||
|
|
||||||
|
from litellm.proxy._types import LitellmUserRoles, SSOUserDefinedValues
|
||||||
|
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||||
|
new_user,
|
||||||
|
user_info,
|
||||||
|
)
|
||||||
|
from litellm.proxy.management_endpoints.team_endpoints import add_new_member
|
||||||
|
|
||||||
|
|
||||||
|
async def custom_sso_handler(userIDPInfo: OpenID) -> SSOUserDefinedValues:
|
||||||
|
try:
|
||||||
|
print("inside custom sso handler") # noqa
|
||||||
|
print(f"userIDPInfo: {userIDPInfo}") # noqa
|
||||||
|
|
||||||
|
if userIDPInfo.id is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"No ID found for user. userIDPInfo.id is None {userIDPInfo}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if user exists in litellm proxy DB
|
||||||
|
_user_info = await user_info(user_id=userIDPInfo.id)
|
||||||
|
print("_user_info from litellm DB ", _user_info) # noqa
|
||||||
|
|
||||||
|
return SSOUserDefinedValues(
|
||||||
|
models=[],
|
||||||
|
user_id=userIDPInfo.id,
|
||||||
|
user_email=userIDPInfo.email,
|
||||||
|
user_role=LitellmUserRoles.INTERNAL_USER.value,
|
||||||
|
max_budget=10,
|
||||||
|
budget_duration="1d",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Failed custom auth")
|
|
@ -7,6 +7,7 @@ These are members of a Team on LiteLLM
|
||||||
/user/new
|
/user/new
|
||||||
/user/update
|
/user/update
|
||||||
/user/delete
|
/user/delete
|
||||||
|
/user/info
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
618
litellm/proxy/management_endpoints/ui_sso.py
Normal file
618
litellm/proxy/management_endpoints/ui_sso.py
Normal file
|
@ -0,0 +1,618 @@
|
||||||
|
"""
|
||||||
|
Has all /sso/* routes
|
||||||
|
|
||||||
|
/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback
|
||||||
|
/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
LitellmUserRoles,
|
||||||
|
ProxyErrorTypes,
|
||||||
|
ProxyException,
|
||||||
|
SSOUserDefinedValues,
|
||||||
|
)
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.proxy.common_utils.admin_ui_utils import (
|
||||||
|
admin_ui_disabled,
|
||||||
|
html_form,
|
||||||
|
show_missing_vars_in_env,
|
||||||
|
)
|
||||||
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
|
||||||
|
async def google_login(request: Request):
|
||||||
|
"""
|
||||||
|
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
||||||
|
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
||||||
|
Example:
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import master_key, premium_user, prisma_client
|
||||||
|
|
||||||
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
|
|
||||||
|
####### Check if UI is disabled #######
|
||||||
|
_disable_ui_flag = os.getenv("DISABLE_ADMIN_UI")
|
||||||
|
if _disable_ui_flag is not None:
|
||||||
|
is_disabled = str_to_bool(value=_disable_ui_flag)
|
||||||
|
if is_disabled:
|
||||||
|
return admin_ui_disabled()
|
||||||
|
|
||||||
|
####### Check if user is a Enterprise / Premium User #######
|
||||||
|
if (
|
||||||
|
microsoft_client_id is not None
|
||||||
|
or google_client_id is not None
|
||||||
|
or generic_client_id is not None
|
||||||
|
):
|
||||||
|
if premium_user != True:
|
||||||
|
raise ProxyException(
|
||||||
|
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="premium_user",
|
||||||
|
code=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
####### Detect DB + MASTER KEY in .env #######
|
||||||
|
missing_env_vars = show_missing_vars_in_env()
|
||||||
|
if missing_env_vars is not None:
|
||||||
|
return missing_env_vars
|
||||||
|
|
||||||
|
# get url from request
|
||||||
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||||
|
ui_username = os.getenv("UI_USERNAME")
|
||||||
|
if redirect_url.endswith("/"):
|
||||||
|
redirect_url += "sso/callback"
|
||||||
|
else:
|
||||||
|
redirect_url += "/sso/callback"
|
||||||
|
# Google SSO Auth
|
||||||
|
if google_client_id is not None:
|
||||||
|
from fastapi_sso.sso.google import GoogleSSO
|
||||||
|
|
||||||
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||||
|
if google_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GOOGLE_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
google_sso = GoogleSSO(
|
||||||
|
client_id=google_client_id,
|
||||||
|
client_secret=google_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
||||||
|
)
|
||||||
|
with google_sso:
|
||||||
|
return await google_sso.get_login_redirect()
|
||||||
|
# Microsoft SSO Auth
|
||||||
|
elif microsoft_client_id is not None:
|
||||||
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||||
|
|
||||||
|
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||||
|
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||||
|
if microsoft_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="MICROSOFT_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
microsoft_sso = MicrosoftSSO(
|
||||||
|
client_id=microsoft_client_id,
|
||||||
|
client_secret=microsoft_client_secret,
|
||||||
|
tenant=microsoft_tenant,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
with microsoft_sso:
|
||||||
|
return await microsoft_sso.get_login_redirect()
|
||||||
|
elif generic_client_id is not None:
|
||||||
|
from fastapi_sso.sso.base import DiscoveryDocument
|
||||||
|
from fastapi_sso.sso.generic import create_provider
|
||||||
|
|
||||||
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||||
|
generic_authorization_endpoint = os.getenv(
|
||||||
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||||
|
)
|
||||||
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
if generic_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_authorization_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_token_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_TOKEN_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_userinfo_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
|
)
|
||||||
|
discovery = DiscoveryDocument(
|
||||||
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
|
token_endpoint=generic_token_endpoint,
|
||||||
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
|
)
|
||||||
|
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||||
|
generic_sso = SSOProvider(
|
||||||
|
client_id=generic_client_id,
|
||||||
|
client_secret=generic_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
scope=generic_scope,
|
||||||
|
)
|
||||||
|
with generic_sso:
|
||||||
|
# TODO: state should be a random string and added to the user session with cookie
|
||||||
|
# or a cryptographicly signed state that we can verify stateless
|
||||||
|
# For simplification we are using a static state, this is not perfect but some
|
||||||
|
# SSO providers do not allow stateless verification
|
||||||
|
redirect_params = {}
|
||||||
|
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||||
|
|
||||||
|
if state:
|
||||||
|
redirect_params["state"] = state
|
||||||
|
elif "okta" in generic_authorization_endpoint:
|
||||||
|
redirect_params["state"] = (
|
||||||
|
uuid.uuid4().hex
|
||||||
|
) # set state param for okta - required
|
||||||
|
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||||
|
elif ui_username is not None:
|
||||||
|
# No Google, Microsoft SSO
|
||||||
|
# Use UI Credentials set in .env
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
|
return HTMLResponse(content=html_form, status_code=200)
|
||||||
|
else:
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
|
return HTMLResponse(content=html_form, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
||||||
|
async def auth_callback(request: Request):
|
||||||
|
"""Verify login"""
|
||||||
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
|
generate_key_helper_fn,
|
||||||
|
)
|
||||||
|
from litellm.proxy.proxy_server import (
|
||||||
|
general_settings,
|
||||||
|
master_key,
|
||||||
|
premium_user,
|
||||||
|
prisma_client,
|
||||||
|
ui_access_mode,
|
||||||
|
user_custom_sso,
|
||||||
|
)
|
||||||
|
|
||||||
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
|
# get url from request
|
||||||
|
if master_key is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="master_key",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||||
|
if redirect_url.endswith("/"):
|
||||||
|
redirect_url += "sso/callback"
|
||||||
|
else:
|
||||||
|
redirect_url += "/sso/callback"
|
||||||
|
|
||||||
|
result = None
|
||||||
|
if google_client_id is not None:
|
||||||
|
from fastapi_sso.sso.google import GoogleSSO
|
||||||
|
|
||||||
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||||
|
if google_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GOOGLE_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
google_sso = GoogleSSO(
|
||||||
|
client_id=google_client_id,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
client_secret=google_client_secret,
|
||||||
|
)
|
||||||
|
result = await google_sso.verify_and_process(request)
|
||||||
|
elif microsoft_client_id is not None:
|
||||||
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||||
|
|
||||||
|
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||||
|
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||||
|
if microsoft_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="MICROSOFT_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if microsoft_tenant is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_TENANT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="MICROSOFT_TENANT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
microsoft_sso = MicrosoftSSO(
|
||||||
|
client_id=microsoft_client_id,
|
||||||
|
client_secret=microsoft_client_secret,
|
||||||
|
tenant=microsoft_tenant,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
result = await microsoft_sso.verify_and_process(request)
|
||||||
|
elif generic_client_id is not None:
|
||||||
|
# make generic sso provider
|
||||||
|
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
|
||||||
|
from fastapi_sso.sso.generic import create_provider
|
||||||
|
|
||||||
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||||
|
generic_authorization_endpoint = os.getenv(
|
||||||
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||||
|
)
|
||||||
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
generic_include_client_id = (
|
||||||
|
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
||||||
|
)
|
||||||
|
if generic_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_authorization_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_token_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_TOKEN_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_userinfo_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
generic_user_id_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
||||||
|
)
|
||||||
|
generic_user_display_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
||||||
|
)
|
||||||
|
generic_user_email_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||||
|
)
|
||||||
|
generic_user_role_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||||
|
)
|
||||||
|
generic_user_first_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
||||||
|
)
|
||||||
|
generic_user_last_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
discovery = DiscoveryDocument(
|
||||||
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
|
token_endpoint=generic_token_endpoint,
|
||||||
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
def response_convertor(response, client):
|
||||||
|
return OpenID(
|
||||||
|
id=response.get(generic_user_id_attribute_name),
|
||||||
|
display_name=response.get(generic_user_display_name_attribute_name),
|
||||||
|
email=response.get(generic_user_email_attribute_name),
|
||||||
|
first_name=response.get(generic_user_first_name_attribute_name),
|
||||||
|
last_name=response.get(generic_user_last_name_attribute_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
SSOProvider = create_provider(
|
||||||
|
name="oidc",
|
||||||
|
discovery_document=discovery,
|
||||||
|
response_convertor=response_convertor,
|
||||||
|
)
|
||||||
|
generic_sso = SSOProvider(
|
||||||
|
client_id=generic_client_id,
|
||||||
|
client_secret=generic_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
scope=generic_scope,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
||||||
|
result = await generic_sso.verify_and_process(
|
||||||
|
request, params={"include_client_id": generic_include_client_id}
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("generic result: %s", result)
|
||||||
|
|
||||||
|
# User is Authe'd in - generate key for the UI to access Proxy
|
||||||
|
user_email: Optional[str] = getattr(result, "email", None)
|
||||||
|
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
||||||
|
|
||||||
|
if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
|
||||||
|
email_domain = user_email.split("@")[1]
|
||||||
|
allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",") # type: ignore
|
||||||
|
if email_domain not in allowed_domains:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={
|
||||||
|
"message": "The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format(
|
||||||
|
email_domain, allowed_domains
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# generic client id
|
||||||
|
if generic_client_id is not None and result is not None:
|
||||||
|
user_id = getattr(result, "id", None)
|
||||||
|
user_email = getattr(result, "email", None)
|
||||||
|
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
|
||||||
|
|
||||||
|
if user_id is None and result is not None:
|
||||||
|
_first_name = getattr(result, "first_name", "") or ""
|
||||||
|
_last_name = getattr(result, "last_name", "") or ""
|
||||||
|
user_id = _first_name + _last_name
|
||||||
|
|
||||||
|
if user_email is not None and (user_id is None or len(user_id) == 0):
|
||||||
|
user_id = user_email
|
||||||
|
|
||||||
|
user_info = None
|
||||||
|
user_id_models: List = []
|
||||||
|
max_internal_user_budget = litellm.max_internal_user_budget
|
||||||
|
internal_user_budget_duration = litellm.internal_user_budget_duration
|
||||||
|
|
||||||
|
# User might not be already created on first generation of key
|
||||||
|
# But if it is, we want their models preferences
|
||||||
|
default_ui_key_values = {
|
||||||
|
"duration": "24hr",
|
||||||
|
"key_max_budget": 0.01,
|
||||||
|
"aliases": {},
|
||||||
|
"config": {},
|
||||||
|
"spend": 0,
|
||||||
|
"team_id": "litellm-dashboard",
|
||||||
|
}
|
||||||
|
user_defined_values: Optional[SSOUserDefinedValues] = None
|
||||||
|
|
||||||
|
if user_custom_sso is not None:
|
||||||
|
if asyncio.iscoroutinefunction(user_custom_sso):
|
||||||
|
user_defined_values = await user_custom_sso(result) # type: ignore
|
||||||
|
else:
|
||||||
|
raise ValueError("user_custom_sso must be a coroutine function")
|
||||||
|
elif user_id is not None:
|
||||||
|
user_defined_values = SSOUserDefinedValues(
|
||||||
|
models=user_id_models,
|
||||||
|
user_id=user_id,
|
||||||
|
user_email=user_email,
|
||||||
|
max_budget=max_internal_user_budget,
|
||||||
|
user_role=None,
|
||||||
|
budget_duration=internal_user_budget_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
_user_id_from_sso = user_id
|
||||||
|
user_role = None
|
||||||
|
try:
|
||||||
|
if prisma_client is not None:
|
||||||
|
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
|
||||||
|
)
|
||||||
|
if user_info is None:
|
||||||
|
## check if user-email in db ##
|
||||||
|
user_info = await prisma_client.db.litellm_usertable.find_first(
|
||||||
|
where={"user_email": user_email}
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_info is not None and user_id is not None:
|
||||||
|
user_defined_values = SSOUserDefinedValues(
|
||||||
|
models=getattr(user_info, "models", user_id_models),
|
||||||
|
user_id=user_id,
|
||||||
|
user_email=getattr(user_info, "user_email", user_email),
|
||||||
|
user_role=getattr(user_info, "user_role", None),
|
||||||
|
max_budget=getattr(
|
||||||
|
user_info, "max_budget", max_internal_user_budget
|
||||||
|
),
|
||||||
|
budget_duration=getattr(
|
||||||
|
user_info, "budget_duration", internal_user_budget_duration
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_role = getattr(user_info, "user_role", None)
|
||||||
|
|
||||||
|
# update id
|
||||||
|
await prisma_client.db.litellm_usertable.update_many(
|
||||||
|
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
||||||
|
)
|
||||||
|
elif litellm.default_user_params is not None and isinstance(
|
||||||
|
litellm.default_user_params, dict
|
||||||
|
):
|
||||||
|
user_defined_values = {
|
||||||
|
"models": litellm.default_user_params.get("models", user_id_models),
|
||||||
|
"user_id": litellm.default_user_params.get("user_id", user_id),
|
||||||
|
"user_email": litellm.default_user_params.get(
|
||||||
|
"user_email", user_email
|
||||||
|
),
|
||||||
|
"user_role": litellm.default_user_params.get("user_role", None),
|
||||||
|
"max_budget": litellm.default_user_params.get(
|
||||||
|
"max_budget", max_internal_user_budget
|
||||||
|
),
|
||||||
|
"budget_duration": litellm.default_user_params.get(
|
||||||
|
"budget_duration", internal_user_budget_duration
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if user_defined_values is None:
|
||||||
|
raise Exception(
|
||||||
|
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
|
||||||
|
)
|
||||||
|
|
||||||
|
is_internal_user = False
|
||||||
|
if (
|
||||||
|
user_defined_values["user_role"] is not None
|
||||||
|
and user_defined_values["user_role"] == LitellmUserRoles.INTERNAL_USER.value
|
||||||
|
):
|
||||||
|
is_internal_user = True
|
||||||
|
if (
|
||||||
|
is_internal_user is True
|
||||||
|
and user_defined_values["max_budget"] is None
|
||||||
|
and litellm.max_internal_user_budget is not None
|
||||||
|
):
|
||||||
|
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||||
|
|
||||||
|
if (
|
||||||
|
is_internal_user is True
|
||||||
|
and user_defined_values["budget_duration"] is None
|
||||||
|
and litellm.internal_user_budget_duration is not None
|
||||||
|
):
|
||||||
|
user_defined_values["budget_duration"] = litellm.internal_user_budget_duration
|
||||||
|
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
f"user_defined_values for creating ui key: {user_defined_values}"
|
||||||
|
)
|
||||||
|
|
||||||
|
default_ui_key_values.update(user_defined_values)
|
||||||
|
default_ui_key_values["request_type"] = "key"
|
||||||
|
response = await generate_key_helper_fn(
|
||||||
|
**default_ui_key_values, # type: ignore
|
||||||
|
)
|
||||||
|
key = response["token"] # type: ignore
|
||||||
|
user_id = response["user_id"] # type: ignore
|
||||||
|
|
||||||
|
# This should always be true
|
||||||
|
# User_id on SSO == user_id in the LiteLLM_VerificationToken Table
|
||||||
|
assert user_id == _user_id_from_sso
|
||||||
|
litellm_dashboard_ui = "/ui/"
|
||||||
|
user_role = user_role or "app_owner"
|
||||||
|
if (
|
||||||
|
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||||
|
and os.environ["PROXY_ADMIN_ID"] == user_id
|
||||||
|
):
|
||||||
|
# checks if user is admin
|
||||||
|
user_role = "app_admin"
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
|
||||||
|
)
|
||||||
|
## CHECK IF ROLE ALLOWED TO USE PROXY ##
|
||||||
|
if ui_access_mode == "admin_only" and "admin" not in user_role:
|
||||||
|
verbose_proxy_logger.debug("EXCEPTION RAISED")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={
|
||||||
|
"error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
jwt_token = jwt.encode( # type: ignore
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"key": key,
|
||||||
|
"user_email": user_email,
|
||||||
|
"user_role": user_role,
|
||||||
|
"login_method": "sso",
|
||||||
|
"premium_user": premium_user,
|
||||||
|
"auth_header_name": general_settings.get(
|
||||||
|
"litellm_key_header_name", "Authorization"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
master_key,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
if user_id is not None and isinstance(user_id, str):
|
||||||
|
litellm_dashboard_ui += "?userID=" + user_id
|
||||||
|
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||||
|
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
||||||
|
return redirect_response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/sso/get/ui_settings",
|
||||||
|
tags=["experimental"],
|
||||||
|
include_in_schema=False,
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
)
|
||||||
|
async def get_ui_settings(request: Request):
|
||||||
|
from litellm.proxy.proxy_server import general_settings
|
||||||
|
|
||||||
|
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
||||||
|
_logout_url = os.getenv("PROXY_LOGOUT_URL", None)
|
||||||
|
|
||||||
|
default_team_disabled = general_settings.get("default_team_disabled", False)
|
||||||
|
if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ:
|
||||||
|
if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true":
|
||||||
|
default_team_disabled = True
|
||||||
|
|
||||||
|
return {
|
||||||
|
"PROXY_BASE_URL": _proxy_base_url,
|
||||||
|
"PROXY_LOGOUT_URL": _logout_url,
|
||||||
|
"DEFAULT_TEAM_DISABLED": default_team_disabled,
|
||||||
|
}
|
|
@ -25,6 +25,7 @@ model_list:
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
default_team_disabled: true
|
default_team_disabled: true
|
||||||
|
custom_sso: custom_sso.custom_sso_handler
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
success_callback: ["prometheus"]
|
success_callback: ["prometheus"]
|
||||||
|
|
|
@ -138,11 +138,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
|
||||||
## Import All Misc routes here ##
|
## Import All Misc routes here ##
|
||||||
from litellm.proxy.caching_routes import router as caching_router
|
from litellm.proxy.caching_routes import router as caching_router
|
||||||
from litellm.proxy.common_utils.admin_ui_utils import (
|
from litellm.proxy.common_utils.admin_ui_utils import html_form
|
||||||
admin_ui_disabled,
|
|
||||||
html_form,
|
|
||||||
show_missing_vars_in_env,
|
|
||||||
)
|
|
||||||
from litellm.proxy.common_utils.callback_utils import (
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
get_logging_caching_headers,
|
get_logging_caching_headers,
|
||||||
get_remaining_tokens_and_requests_from_request_data,
|
get_remaining_tokens_and_requests_from_request_data,
|
||||||
|
@ -194,6 +190,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
|
||||||
router as team_callback_router,
|
router as team_callback_router,
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
||||||
|
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
|
||||||
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
|
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
|
||||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||||
router as openai_files_router,
|
router as openai_files_router,
|
||||||
|
@ -488,6 +485,7 @@ redis_usage_cache: Optional[RedisCache] = (
|
||||||
)
|
)
|
||||||
user_custom_auth = None
|
user_custom_auth = None
|
||||||
user_custom_key_generate = None
|
user_custom_key_generate = None
|
||||||
|
user_custom_sso = None
|
||||||
use_background_health_checks = None
|
use_background_health_checks = None
|
||||||
use_queue = False
|
use_queue = False
|
||||||
health_check_interval = None
|
health_check_interval = None
|
||||||
|
@ -1484,7 +1482,7 @@ class ProxyConfig:
|
||||||
"""
|
"""
|
||||||
Load config values into proxy global state
|
Load config values into proxy global state
|
||||||
"""
|
"""
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings
|
||||||
|
|
||||||
# Load existing config
|
# Load existing config
|
||||||
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
|
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
|
||||||
|
@ -1846,6 +1844,13 @@ class ProxyConfig:
|
||||||
user_custom_key_generate = get_instance_fn(
|
user_custom_key_generate = get_instance_fn(
|
||||||
value=custom_key_generate, config_file_path=config_file_path
|
value=custom_key_generate, config_file_path=config_file_path
|
||||||
)
|
)
|
||||||
|
|
||||||
|
custom_sso = general_settings.get("custom_sso", None)
|
||||||
|
if custom_sso is not None:
|
||||||
|
user_custom_sso = get_instance_fn(
|
||||||
|
value=custom_sso, config_file_path=config_file_path
|
||||||
|
)
|
||||||
|
|
||||||
## pass through endpoints
|
## pass through endpoints
|
||||||
if general_settings.get("pass_through_endpoints", None) is not None:
|
if general_settings.get("pass_through_endpoints", None) is not None:
|
||||||
await initialize_pass_through_endpoints(
|
await initialize_pass_through_endpoints(
|
||||||
|
@ -7964,184 +7969,6 @@ async def async_queue_request(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#### LOGIN ENDPOINTS ####
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
|
|
||||||
async def google_login(request: Request):
|
|
||||||
"""
|
|
||||||
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
|
||||||
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
|
||||||
Example:
|
|
||||||
"""
|
|
||||||
global premium_user, prisma_client, master_key
|
|
||||||
|
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
|
||||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
|
||||||
|
|
||||||
####### Check if UI is disabled #######
|
|
||||||
_disable_ui_flag = os.getenv("DISABLE_ADMIN_UI")
|
|
||||||
if _disable_ui_flag is not None:
|
|
||||||
is_disabled = str_to_bool(value=_disable_ui_flag)
|
|
||||||
if is_disabled:
|
|
||||||
return admin_ui_disabled()
|
|
||||||
|
|
||||||
####### Check if user is a Enterprise / Premium User #######
|
|
||||||
if (
|
|
||||||
microsoft_client_id is not None
|
|
||||||
or google_client_id is not None
|
|
||||||
or generic_client_id is not None
|
|
||||||
):
|
|
||||||
if premium_user != True:
|
|
||||||
raise ProxyException(
|
|
||||||
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="premium_user",
|
|
||||||
code=status.HTTP_403_FORBIDDEN,
|
|
||||||
)
|
|
||||||
|
|
||||||
####### Detect DB + MASTER KEY in .env #######
|
|
||||||
missing_env_vars = show_missing_vars_in_env()
|
|
||||||
if missing_env_vars is not None:
|
|
||||||
return missing_env_vars
|
|
||||||
|
|
||||||
# get url from request
|
|
||||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
|
||||||
ui_username = os.getenv("UI_USERNAME")
|
|
||||||
if redirect_url.endswith("/"):
|
|
||||||
redirect_url += "sso/callback"
|
|
||||||
else:
|
|
||||||
redirect_url += "/sso/callback"
|
|
||||||
# Google SSO Auth
|
|
||||||
if google_client_id is not None:
|
|
||||||
from fastapi_sso.sso.google import GoogleSSO
|
|
||||||
|
|
||||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
|
||||||
if google_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GOOGLE_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
google_sso = GoogleSSO(
|
|
||||||
client_id=google_client_id,
|
|
||||||
client_secret=google_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.info(
|
|
||||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
|
||||||
)
|
|
||||||
with google_sso:
|
|
||||||
return await google_sso.get_login_redirect()
|
|
||||||
# Microsoft SSO Auth
|
|
||||||
elif microsoft_client_id is not None:
|
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
|
||||||
|
|
||||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
|
||||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
|
||||||
if microsoft_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="MICROSOFT_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
microsoft_sso = MicrosoftSSO(
|
|
||||||
client_id=microsoft_client_id,
|
|
||||||
client_secret=microsoft_client_secret,
|
|
||||||
tenant=microsoft_tenant,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
)
|
|
||||||
with microsoft_sso:
|
|
||||||
return await microsoft_sso.get_login_redirect()
|
|
||||||
elif generic_client_id is not None:
|
|
||||||
from fastapi_sso.sso.base import DiscoveryDocument
|
|
||||||
from fastapi_sso.sso.generic import create_provider
|
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
|
||||||
generic_authorization_endpoint = os.getenv(
|
|
||||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
||||||
)
|
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
||||||
if generic_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_authorization_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_token_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_TOKEN_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_userinfo_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_USERINFO_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
||||||
)
|
|
||||||
discovery = DiscoveryDocument(
|
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
|
||||||
token_endpoint=generic_token_endpoint,
|
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
|
||||||
)
|
|
||||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
|
||||||
generic_sso = SSOProvider(
|
|
||||||
client_id=generic_client_id,
|
|
||||||
client_secret=generic_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
scope=generic_scope,
|
|
||||||
)
|
|
||||||
with generic_sso:
|
|
||||||
# TODO: state should be a random string and added to the user session with cookie
|
|
||||||
# or a cryptographicly signed state that we can verify stateless
|
|
||||||
# For simplification we are using a static state, this is not perfect but some
|
|
||||||
# SSO providers do not allow stateless verification
|
|
||||||
redirect_params = {}
|
|
||||||
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
|
||||||
|
|
||||||
if state:
|
|
||||||
redirect_params["state"] = state
|
|
||||||
elif "okta" in generic_authorization_endpoint:
|
|
||||||
redirect_params["state"] = (
|
|
||||||
uuid.uuid4().hex
|
|
||||||
) # set state param for okta - required
|
|
||||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
|
||||||
elif ui_username is not None:
|
|
||||||
# No Google, Microsoft SSO
|
|
||||||
# Use UI Credentials set in .env
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
|
|
||||||
return HTMLResponse(content=html_form, status_code=200)
|
|
||||||
else:
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
|
|
||||||
return HTMLResponse(content=html_form, status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/fallback/login", tags=["experimental"], include_in_schema=False)
|
@app.get("/fallback/login", tags=["experimental"], include_in_schema=False)
|
||||||
async def fallback_login(request: Request):
|
async def fallback_login(request: Request):
|
||||||
"""
|
"""
|
||||||
|
@ -8371,28 +8198,6 @@ async def login(request: Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get(
|
|
||||||
"/sso/get/ui_settings",
|
|
||||||
tags=["experimental"],
|
|
||||||
include_in_schema=False,
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
)
|
|
||||||
async def get_ui_settings(request: Request):
|
|
||||||
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
|
||||||
_logout_url = os.getenv("PROXY_LOGOUT_URL", None)
|
|
||||||
|
|
||||||
default_team_disabled = general_settings.get("default_team_disabled", False)
|
|
||||||
if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ:
|
|
||||||
if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true":
|
|
||||||
default_team_disabled = True
|
|
||||||
|
|
||||||
return {
|
|
||||||
"PROXY_BASE_URL": _proxy_base_url,
|
|
||||||
"PROXY_LOGOUT_URL": _logout_url,
|
|
||||||
"DEFAULT_TEAM_DISABLED": default_team_disabled,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/onboarding/get_token", include_in_schema=False)
|
@app.get("/onboarding/get_token", include_in_schema=False)
|
||||||
async def onboarding(invite_link: str):
|
async def onboarding(invite_link: str):
|
||||||
"""
|
"""
|
||||||
|
@ -8613,376 +8418,6 @@ def get_image():
|
||||||
return FileResponse(logo_path, media_type="image/jpeg")
|
return FileResponse(logo_path, media_type="image/jpeg")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
|
||||||
async def auth_callback(request: Request):
|
|
||||||
"""Verify login"""
|
|
||||||
global general_settings, ui_access_mode, premium_user, master_key
|
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
|
||||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
|
||||||
# get url from request
|
|
||||||
if master_key is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="master_key",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
|
||||||
if redirect_url.endswith("/"):
|
|
||||||
redirect_url += "sso/callback"
|
|
||||||
else:
|
|
||||||
redirect_url += "/sso/callback"
|
|
||||||
|
|
||||||
result = None
|
|
||||||
if google_client_id is not None:
|
|
||||||
from fastapi_sso.sso.google import GoogleSSO
|
|
||||||
|
|
||||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
|
||||||
if google_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GOOGLE_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
google_sso = GoogleSSO(
|
|
||||||
client_id=google_client_id,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
client_secret=google_client_secret,
|
|
||||||
)
|
|
||||||
result = await google_sso.verify_and_process(request)
|
|
||||||
elif microsoft_client_id is not None:
|
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
|
||||||
|
|
||||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
|
||||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
|
||||||
if microsoft_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="MICROSOFT_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if microsoft_tenant is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_TENANT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="MICROSOFT_TENANT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
microsoft_sso = MicrosoftSSO(
|
|
||||||
client_id=microsoft_client_id,
|
|
||||||
client_secret=microsoft_client_secret,
|
|
||||||
tenant=microsoft_tenant,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
)
|
|
||||||
result = await microsoft_sso.verify_and_process(request)
|
|
||||||
elif generic_client_id is not None:
|
|
||||||
# make generic sso provider
|
|
||||||
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
|
|
||||||
from fastapi_sso.sso.generic import create_provider
|
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
|
||||||
generic_authorization_endpoint = os.getenv(
|
|
||||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
||||||
)
|
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
||||||
generic_include_client_id = (
|
|
||||||
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
|
||||||
)
|
|
||||||
if generic_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_authorization_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_token_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_TOKEN_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_userinfo_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_USERINFO_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
generic_user_id_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
|
||||||
)
|
|
||||||
generic_user_display_name_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
|
||||||
)
|
|
||||||
generic_user_email_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
|
||||||
)
|
|
||||||
generic_user_role_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
|
||||||
)
|
|
||||||
generic_user_first_name_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
|
||||||
)
|
|
||||||
generic_user_last_name_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
discovery = DiscoveryDocument(
|
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
|
||||||
token_endpoint=generic_token_endpoint,
|
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
def response_convertor(response, client):
|
|
||||||
return OpenID(
|
|
||||||
id=response.get(generic_user_id_attribute_name),
|
|
||||||
display_name=response.get(generic_user_display_name_attribute_name),
|
|
||||||
email=response.get(generic_user_email_attribute_name),
|
|
||||||
first_name=response.get(generic_user_first_name_attribute_name),
|
|
||||||
last_name=response.get(generic_user_last_name_attribute_name),
|
|
||||||
)
|
|
||||||
|
|
||||||
SSOProvider = create_provider(
|
|
||||||
name="oidc",
|
|
||||||
discovery_document=discovery,
|
|
||||||
response_convertor=response_convertor,
|
|
||||||
)
|
|
||||||
generic_sso = SSOProvider(
|
|
||||||
client_id=generic_client_id,
|
|
||||||
client_secret=generic_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
scope=generic_scope,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
|
||||||
result = await generic_sso.verify_and_process(
|
|
||||||
request, params={"include_client_id": generic_include_client_id}
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug("generic result: %s", result)
|
|
||||||
|
|
||||||
# User is Authe'd in - generate key for the UI to access Proxy
|
|
||||||
user_email: Optional[str] = getattr(result, "email", None)
|
|
||||||
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
|
||||||
|
|
||||||
if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
|
|
||||||
email_domain = user_email.split("@")[1]
|
|
||||||
allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",") # type: ignore
|
|
||||||
if email_domain not in allowed_domains:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail={
|
|
||||||
"message": "The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format(
|
|
||||||
email_domain, allowed_domains
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# generic client id
|
|
||||||
if generic_client_id is not None and result is not None:
|
|
||||||
user_id = getattr(result, "id", None)
|
|
||||||
user_email = getattr(result, "email", None)
|
|
||||||
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
|
|
||||||
|
|
||||||
if user_id is None and result is not None:
|
|
||||||
_first_name = getattr(result, "first_name", "") or ""
|
|
||||||
_last_name = getattr(result, "last_name", "") or ""
|
|
||||||
user_id = _first_name + _last_name
|
|
||||||
|
|
||||||
if user_email is not None and (user_id is None or len(user_id) == 0):
|
|
||||||
user_id = user_email
|
|
||||||
|
|
||||||
user_info = None
|
|
||||||
user_id_models: List = []
|
|
||||||
max_internal_user_budget = litellm.max_internal_user_budget
|
|
||||||
internal_user_budget_duration = litellm.internal_user_budget_duration
|
|
||||||
|
|
||||||
# User might not be already created on first generation of key
|
|
||||||
# But if it is, we want their models preferences
|
|
||||||
default_ui_key_values = {
|
|
||||||
"duration": "24hr",
|
|
||||||
"key_max_budget": 0.01,
|
|
||||||
"aliases": {},
|
|
||||||
"config": {},
|
|
||||||
"spend": 0,
|
|
||||||
"team_id": "litellm-dashboard",
|
|
||||||
}
|
|
||||||
user_defined_values: Optional[SSOUserDefinedValues] = None
|
|
||||||
if user_id is not None:
|
|
||||||
user_defined_values = SSOUserDefinedValues(
|
|
||||||
models=user_id_models,
|
|
||||||
user_id=user_id,
|
|
||||||
user_email=user_email,
|
|
||||||
max_budget=max_internal_user_budget,
|
|
||||||
user_role=None,
|
|
||||||
budget_duration=internal_user_budget_duration,
|
|
||||||
)
|
|
||||||
|
|
||||||
_user_id_from_sso = user_id
|
|
||||||
user_role = None
|
|
||||||
try:
|
|
||||||
if prisma_client is not None:
|
|
||||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
|
|
||||||
)
|
|
||||||
if user_info is None:
|
|
||||||
## check if user-email in db ##
|
|
||||||
user_info = await prisma_client.db.litellm_usertable.find_first(
|
|
||||||
where={"user_email": user_email}
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_info is not None and user_id is not None:
|
|
||||||
user_defined_values = SSOUserDefinedValues(
|
|
||||||
models=getattr(user_info, "models", user_id_models),
|
|
||||||
user_id=user_id,
|
|
||||||
user_email=getattr(user_info, "user_email", user_email),
|
|
||||||
user_role=getattr(user_info, "user_role", None),
|
|
||||||
max_budget=getattr(
|
|
||||||
user_info, "max_budget", max_internal_user_budget
|
|
||||||
),
|
|
||||||
budget_duration=getattr(
|
|
||||||
user_info, "budget_duration", internal_user_budget_duration
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
user_role = getattr(user_info, "user_role", None)
|
|
||||||
|
|
||||||
# update id
|
|
||||||
await prisma_client.db.litellm_usertable.update_many(
|
|
||||||
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
|
||||||
)
|
|
||||||
elif litellm.default_user_params is not None and isinstance(
|
|
||||||
litellm.default_user_params, dict
|
|
||||||
):
|
|
||||||
user_defined_values = {
|
|
||||||
"models": litellm.default_user_params.get("models", user_id_models),
|
|
||||||
"user_id": litellm.default_user_params.get("user_id", user_id),
|
|
||||||
"user_email": litellm.default_user_params.get(
|
|
||||||
"user_email", user_email
|
|
||||||
),
|
|
||||||
"user_role": litellm.default_user_params.get("user_role", None),
|
|
||||||
"max_budget": litellm.default_user_params.get(
|
|
||||||
"max_budget", max_internal_user_budget
|
|
||||||
),
|
|
||||||
"budget_duration": litellm.default_user_params.get(
|
|
||||||
"budget_duration", internal_user_budget_duration
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if user_defined_values is None:
|
|
||||||
raise Exception(
|
|
||||||
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
|
|
||||||
)
|
|
||||||
|
|
||||||
is_internal_user = False
|
|
||||||
if (
|
|
||||||
user_defined_values["user_role"] is not None
|
|
||||||
and user_defined_values["user_role"] == LitellmUserRoles.INTERNAL_USER.value
|
|
||||||
):
|
|
||||||
is_internal_user = True
|
|
||||||
if (
|
|
||||||
is_internal_user is True
|
|
||||||
and user_defined_values["max_budget"] is None
|
|
||||||
and litellm.max_internal_user_budget is not None
|
|
||||||
):
|
|
||||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
|
||||||
|
|
||||||
if (
|
|
||||||
is_internal_user is True
|
|
||||||
and user_defined_values["budget_duration"] is None
|
|
||||||
and litellm.internal_user_budget_duration is not None
|
|
||||||
):
|
|
||||||
user_defined_values["budget_duration"] = litellm.internal_user_budget_duration
|
|
||||||
|
|
||||||
verbose_proxy_logger.info(
|
|
||||||
f"user_defined_values for creating ui key: {user_defined_values}"
|
|
||||||
)
|
|
||||||
|
|
||||||
default_ui_key_values.update(user_defined_values)
|
|
||||||
default_ui_key_values["request_type"] = "key"
|
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
**default_ui_key_values, # type: ignore
|
|
||||||
)
|
|
||||||
key = response["token"] # type: ignore
|
|
||||||
user_id = response["user_id"] # type: ignore
|
|
||||||
|
|
||||||
# This should always be true
|
|
||||||
# User_id on SSO == user_id in the LiteLLM_VerificationToken Table
|
|
||||||
assert user_id == _user_id_from_sso
|
|
||||||
litellm_dashboard_ui = "/ui/"
|
|
||||||
user_role = user_role or "app_owner"
|
|
||||||
if (
|
|
||||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
|
||||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
|
||||||
):
|
|
||||||
# checks if user is admin
|
|
||||||
user_role = "app_admin"
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
|
|
||||||
)
|
|
||||||
## CHECK IF ROLE ALLOWED TO USE PROXY ##
|
|
||||||
if ui_access_mode == "admin_only" and "admin" not in user_role:
|
|
||||||
verbose_proxy_logger.debug("EXCEPTION RAISED")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail={
|
|
||||||
"error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
|
|
||||||
jwt_token = jwt.encode( # type: ignore
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
|
||||||
"key": key,
|
|
||||||
"user_email": user_email,
|
|
||||||
"user_role": user_role,
|
|
||||||
"login_method": "sso",
|
|
||||||
"premium_user": premium_user,
|
|
||||||
"auth_header_name": general_settings.get(
|
|
||||||
"litellm_key_header_name", "Authorization"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
master_key,
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
|
||||||
if user_id is not None and isinstance(user_id, str):
|
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
|
||||||
return redirect_response
|
|
||||||
|
|
||||||
|
|
||||||
#### INVITATION MANAGEMENT ####
|
#### INVITATION MANAGEMENT ####
|
||||||
|
|
||||||
|
|
||||||
|
@ -10044,7 +9479,7 @@ async def shutdown_event():
|
||||||
|
|
||||||
|
|
||||||
def cleanup_router_config_variables():
|
def cleanup_router_config_variables():
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
|
||||||
|
|
||||||
# Set all variables to None
|
# Set all variables to None
|
||||||
master_key = None
|
master_key = None
|
||||||
|
@ -10053,6 +9488,7 @@ def cleanup_router_config_variables():
|
||||||
user_custom_auth = None
|
user_custom_auth = None
|
||||||
user_custom_auth_path = None
|
user_custom_auth_path = None
|
||||||
user_custom_key_generate = None
|
user_custom_key_generate = None
|
||||||
|
user_custom_sso = None
|
||||||
use_background_health_checks = None
|
use_background_health_checks = None
|
||||||
health_check_interval = None
|
health_check_interval = None
|
||||||
health_check_details = None
|
health_check_details = None
|
||||||
|
@ -10071,6 +9507,7 @@ app.include_router(health_router)
|
||||||
app.include_router(key_management_router)
|
app.include_router(key_management_router)
|
||||||
app.include_router(internal_user_router)
|
app.include_router(internal_user_router)
|
||||||
app.include_router(team_router)
|
app.include_router(team_router)
|
||||||
|
app.include_router(ui_sso_router)
|
||||||
app.include_router(spend_management_router)
|
app.include_router(spend_management_router)
|
||||||
app.include_router(caching_router)
|
app.include_router(caching_router)
|
||||||
app.include_router(analytics_router)
|
app.include_router(analytics_router)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue