forked from phoenix/litellm-mirror
fix(proxy_server.py): fix merge conflicts
This commit is contained in:
parent
122ad77d56
commit
5ef33b1d5e
1 changed files with 65 additions and 212 deletions
|
@ -2782,7 +2782,6 @@ async def image_generation(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/moderations",
|
"/v1/moderations",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
@ -3822,6 +3821,7 @@ async def user_update(data: UpdateUserRequest):
|
||||||
code=status.HTTP_400_BAD_REQUEST,
|
code=status.HTTP_400_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#### TEAM MANAGEMENT ####
|
#### TEAM MANAGEMENT ####
|
||||||
|
|
||||||
|
|
||||||
|
@ -3905,180 +3905,6 @@ async def team_info(
|
||||||
get info on team + related keys
|
get info on team + related keys
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@app.get("/sso/callback", tags=["experimental"])
|
|
||||||
async def auth_callback(request: Request):
|
|
||||||
"""Verify login"""
|
|
||||||
global general_settings
|
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
|
||||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
|
||||||
|
|
||||||
# get url from request
|
|
||||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
|
||||||
|
|
||||||
if redirect_url.endswith("/"):
|
|
||||||
redirect_url += "sso/callback"
|
|
||||||
else:
|
|
||||||
redirect_url += "/sso/callback"
|
|
||||||
|
|
||||||
if google_client_id is not None:
|
|
||||||
from fastapi_sso.sso.google import GoogleSSO
|
|
||||||
|
|
||||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
|
||||||
if google_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GOOGLE_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
google_sso = GoogleSSO(
|
|
||||||
client_id=google_client_id,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
client_secret=google_client_secret,
|
|
||||||
)
|
|
||||||
result = await google_sso.verify_and_process(request)
|
|
||||||
|
|
||||||
elif microsoft_client_id is not None:
|
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
|
||||||
|
|
||||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
|
||||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
|
||||||
if microsoft_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="MICROSOFT_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if microsoft_tenant is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="MICROSOFT_TENANT not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="MICROSOFT_TENANT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
|
|
||||||
microsoft_sso = MicrosoftSSO(
|
|
||||||
client_id=microsoft_client_id,
|
|
||||||
client_secret=microsoft_client_secret,
|
|
||||||
tenant=microsoft_tenant,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
)
|
|
||||||
result = await microsoft_sso.verify_and_process(request)
|
|
||||||
elif generic_client_id is not None:
|
|
||||||
# make generic sso provider
|
|
||||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
|
||||||
generic_authorization_endpoint = os.getenv(
|
|
||||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
||||||
)
|
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
||||||
if generic_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_authorization_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_token_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_TOKEN_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_userinfo_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_USERINFO_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
discovery = DiscoveryDocument(
|
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
|
||||||
token_endpoint=generic_token_endpoint,
|
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
|
||||||
generic_sso = SSOProvider(
|
|
||||||
client_id=generic_client_id,
|
|
||||||
client_secret=generic_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
|
||||||
|
|
||||||
request_body = await request.body()
|
|
||||||
|
|
||||||
request_query_params = request.query_params
|
|
||||||
|
|
||||||
# get "code" from query params
|
|
||||||
code = request_query_params.get("code")
|
|
||||||
|
|
||||||
result = await generic_sso.verify_and_process(request)
|
|
||||||
verbose_proxy_logger.debug(f"generic result: {result}")
|
|
||||||
|
|
||||||
# User is Authe'd in - generate key for the UI to access Proxy
|
|
||||||
user_email = getattr(result, "email", None)
|
|
||||||
user_id = getattr(result, "id", None)
|
|
||||||
if user_id is None:
|
|
||||||
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
|
||||||
|
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
|
|
||||||
)
|
|
||||||
key = response["token"] # type: ignore
|
|
||||||
user_id = response["user_id"] # type: ignore
|
|
||||||
|
|
||||||
litellm_dashboard_ui = "/ui/"
|
|
||||||
|
|
||||||
user_role = "app_owner"
|
|
||||||
if (
|
|
||||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
|
||||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
|
||||||
):
|
|
||||||
# checks if user is admin
|
|
||||||
user_role = "app_admin"
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
|
|
||||||
jwt_token = jwt.encode(
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
|
||||||
"key": key,
|
|
||||||
"user_email": user_email,
|
|
||||||
"user_role": user_role,
|
|
||||||
},
|
|
||||||
"secret",
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
|
||||||
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
|
||||||
|
|
||||||
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
|
|
||||||
general_settings["allow_user_auth"] = True
|
|
||||||
return RedirectResponse(url=litellm_dashboard_ui)
|
|
||||||
|
|
||||||
|
|
||||||
#### MODEL MANAGEMENT ####
|
#### MODEL MANAGEMENT ####
|
||||||
|
@ -4439,17 +4265,14 @@ async def retrieve_server_log(request: Request):
|
||||||
async def google_login(request: Request):
|
async def google_login(request: Request):
|
||||||
"""
|
"""
|
||||||
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
||||||
|
|
||||||
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
# get url from request
|
# get url from request
|
||||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||||
|
|
||||||
ui_username = os.getenv("UI_USERNAME")
|
ui_username = os.getenv("UI_USERNAME")
|
||||||
if redirect_url.endswith("/"):
|
if redirect_url.endswith("/"):
|
||||||
redirect_url += "sso/callback"
|
redirect_url += "sso/callback"
|
||||||
|
@ -4467,20 +4290,16 @@ async def google_login(request: Request):
|
||||||
param="GOOGLE_CLIENT_SECRET",
|
param="GOOGLE_CLIENT_SECRET",
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
google_sso = GoogleSSO(
|
google_sso = GoogleSSO(
|
||||||
client_id=google_client_id,
|
client_id=google_client_id,
|
||||||
client_secret=google_client_secret,
|
client_secret=google_client_secret,
|
||||||
redirect_uri=redirect_url,
|
redirect_uri=redirect_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
with google_sso:
|
with google_sso:
|
||||||
return await google_sso.get_login_redirect()
|
return await google_sso.get_login_redirect()
|
||||||
|
|
||||||
# Microsoft SSO Auth
|
# Microsoft SSO Auth
|
||||||
elif microsoft_client_id is not None:
|
elif microsoft_client_id is not None:
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||||
|
@ -4494,7 +4313,6 @@ async def google_login(request: Request):
|
||||||
param="MICROSOFT_CLIENT_SECRET",
|
param="MICROSOFT_CLIENT_SECRET",
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
microsoft_sso = MicrosoftSSO(
|
microsoft_sso = MicrosoftSSO(
|
||||||
client_id=microsoft_client_id,
|
client_id=microsoft_client_id,
|
||||||
client_secret=microsoft_client_secret,
|
client_secret=microsoft_client_secret,
|
||||||
|
@ -4541,21 +4359,17 @@ async def google_login(request: Request):
|
||||||
param="GENERIC_USERINFO_ENDPOINT",
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
discovery = DiscoveryDocument(
|
discovery = DiscoveryDocument(
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
token_endpoint=generic_token_endpoint,
|
token_endpoint=generic_token_endpoint,
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||||
generic_sso = SSOProvider(
|
generic_sso = SSOProvider(
|
||||||
client_id=generic_client_id,
|
client_id=generic_client_id,
|
||||||
|
@ -4563,14 +4377,8 @@ async def google_login(request: Request):
|
||||||
redirect_uri=redirect_url,
|
redirect_uri=redirect_url,
|
||||||
allow_insecure_http=True,
|
allow_insecure_http=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with generic_sso:
|
with generic_sso:
|
||||||
return await generic_sso.get_login_redirect()
|
return await generic_sso.get_login_redirect()
|
||||||
|
|
||||||
elif ui_username is not None:
|
|
||||||
# No Google, Microsoft SSO
|
|
||||||
# Use UI Credentials set in .env
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
elif ui_username is not None:
|
elif ui_username is not None:
|
||||||
# No Google, Microsoft SSO
|
# No Google, Microsoft SSO
|
||||||
# Use UI Credentials set in .env
|
# Use UI Credentials set in .env
|
||||||
|
@ -4599,7 +4407,6 @@ async def login(request: Request):
|
||||||
ui_password = os.getenv("UI_PASSWORD", None)
|
ui_password = os.getenv("UI_PASSWORD", None)
|
||||||
if ui_password is None:
|
if ui_password is None:
|
||||||
ui_password = str(master_key) if master_key is not None else None
|
ui_password = str(master_key) if master_key is not None else None
|
||||||
|
|
||||||
if ui_password is None:
|
if ui_password is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
|
message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
|
||||||
|
@ -4607,7 +4414,6 @@ async def login(request: Request):
|
||||||
param="UI_PASSWORD",
|
param="UI_PASSWORD",
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
|
if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
|
||||||
password, ui_password
|
password, ui_password
|
||||||
):
|
):
|
||||||
|
@ -4621,9 +4427,7 @@ async def login(request: Request):
|
||||||
# checks if user is admin
|
# checks if user is admin
|
||||||
user_role = "app_admin"
|
user_role = "app_admin"
|
||||||
key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")
|
key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")
|
||||||
|
|
||||||
# Admin is Authe'd in - generate key for the UI to access Proxy
|
# Admin is Authe'd in - generate key for the UI to access Proxy
|
||||||
|
|
||||||
if os.getenv("DATABASE_URL") is not None:
|
if os.getenv("DATABASE_URL") is not None:
|
||||||
response = await generate_key_helper_fn(
|
response = await generate_key_helper_fn(
|
||||||
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
|
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
|
||||||
|
@ -4633,11 +4437,8 @@ async def login(request: Request):
|
||||||
"token": "sk-gm",
|
"token": "sk-gm",
|
||||||
"user_id": "litellm-dashboard",
|
"user_id": "litellm-dashboard",
|
||||||
}
|
}
|
||||||
|
|
||||||
key = response["token"] # type: ignore
|
key = response["token"] # type: ignore
|
||||||
|
|
||||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/"
|
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/"
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
jwt_token = jwt.encode(
|
jwt_token = jwt.encode(
|
||||||
|
@ -4651,7 +4452,6 @@ async def login(request: Request):
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
||||||
|
|
||||||
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
|
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
|
||||||
general_settings["allow_user_auth"] = True
|
general_settings["allow_user_auth"] = True
|
||||||
return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||||
|
@ -4670,15 +4470,13 @@ async def auth_callback(request: Request):
|
||||||
global general_settings
|
global general_settings
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
# get url from request
|
# get url from request
|
||||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||||
|
|
||||||
if redirect_url.endswith("/"):
|
if redirect_url.endswith("/"):
|
||||||
redirect_url += "sso/callback"
|
redirect_url += "sso/callback"
|
||||||
else:
|
else:
|
||||||
redirect_url += "/sso/callback"
|
redirect_url += "/sso/callback"
|
||||||
|
|
||||||
if google_client_id is not None:
|
if google_client_id is not None:
|
||||||
from fastapi_sso.sso.google import GoogleSSO
|
from fastapi_sso.sso.google import GoogleSSO
|
||||||
|
|
||||||
|
@ -4696,7 +4494,6 @@ async def auth_callback(request: Request):
|
||||||
client_secret=google_client_secret,
|
client_secret=google_client_secret,
|
||||||
)
|
)
|
||||||
result = await google_sso.verify_and_process(request)
|
result = await google_sso.verify_and_process(request)
|
||||||
|
|
||||||
elif microsoft_client_id is not None:
|
elif microsoft_client_id is not None:
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||||
|
|
||||||
|
@ -4716,7 +4513,6 @@ async def auth_callback(request: Request):
|
||||||
param="MICROSOFT_TENANT",
|
param="MICROSOFT_TENANT",
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
microsoft_sso = MicrosoftSSO(
|
microsoft_sso = MicrosoftSSO(
|
||||||
client_id=microsoft_client_id,
|
client_id=microsoft_client_id,
|
||||||
client_secret=microsoft_client_secret,
|
client_secret=microsoft_client_secret,
|
||||||
|
@ -4725,21 +4521,80 @@ async def auth_callback(request: Request):
|
||||||
allow_insecure_http=True,
|
allow_insecure_http=True,
|
||||||
)
|
)
|
||||||
result = await microsoft_sso.verify_and_process(request)
|
result = await microsoft_sso.verify_and_process(request)
|
||||||
|
elif generic_client_id is not None:
|
||||||
|
# make generic sso provider
|
||||||
|
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
||||||
|
|
||||||
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
|
generic_authorization_endpoint = os.getenv(
|
||||||
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||||
|
)
|
||||||
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
if generic_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_authorization_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_token_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_TOKEN_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_userinfo_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
|
)
|
||||||
|
discovery = DiscoveryDocument(
|
||||||
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
|
token_endpoint=generic_token_endpoint,
|
||||||
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
|
)
|
||||||
|
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||||
|
generic_sso = SSOProvider(
|
||||||
|
client_id=generic_client_id,
|
||||||
|
client_secret=generic_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
||||||
|
request_body = await request.body()
|
||||||
|
request_query_params = request.query_params
|
||||||
|
# get "code" from query params
|
||||||
|
code = request_query_params.get("code")
|
||||||
|
result = await generic_sso.verify_and_process(request)
|
||||||
|
verbose_proxy_logger.debug(f"generic result: {result}")
|
||||||
# User is Authe'd in - generate key for the UI to access Proxy
|
# User is Authe'd in - generate key for the UI to access Proxy
|
||||||
user_email = getattr(result, "email", None)
|
user_email = getattr(result, "email", None)
|
||||||
user_id = getattr(result, "id", None)
|
user_id = getattr(result, "id", None)
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
||||||
|
|
||||||
response = await generate_key_helper_fn(
|
response = await generate_key_helper_fn(
|
||||||
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
|
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
|
||||||
)
|
)
|
||||||
key = response["token"] # type: ignore
|
key = response["token"] # type: ignore
|
||||||
user_id = response["user_id"] # type: ignore
|
user_id = response["user_id"] # type: ignore
|
||||||
|
|
||||||
litellm_dashboard_ui = "/ui/"
|
litellm_dashboard_ui = "/ui/"
|
||||||
|
|
||||||
user_role = "app_owner"
|
user_role = "app_owner"
|
||||||
if (
|
if (
|
||||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||||
|
@ -4747,7 +4602,6 @@ async def auth_callback(request: Request):
|
||||||
):
|
):
|
||||||
# checks if user is admin
|
# checks if user is admin
|
||||||
user_role = "app_admin"
|
user_role = "app_admin"
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
jwt_token = jwt.encode(
|
jwt_token = jwt.encode(
|
||||||
|
@ -4761,7 +4615,6 @@ async def auth_callback(request: Request):
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
||||||
|
|
||||||
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
|
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
|
||||||
general_settings["allow_user_auth"] = True
|
general_settings["allow_user_auth"] = True
|
||||||
return RedirectResponse(url=litellm_dashboard_ui)
|
return RedirectResponse(url=litellm_dashboard_ui)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue