Merge pull request #1979 from BerriAI/litellm_team_settings

refactor(proxy_server.py): initial stubbed endpoints for team management
This commit is contained in:
Krish Dholakia 2024-02-14 21:47:14 -08:00 committed by GitHub
commit 15bd6266b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 625 additions and 321 deletions

View file

@ -10,6 +10,7 @@
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx

View file

@ -21,10 +21,10 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs):
    """
    Serialize the model to a plain dict.

    Any keyword arguments (e.g. ``exclude_none=True``) are forwarded to the
    underlying pydantic serializer so callers can control the output.

    Returns:
        dict: the result of ``model_dump(**kwargs)`` (pydantic v2) or, when
        that call fails, ``dict(**kwargs)`` (pydantic v1).
    """
    try:
        return self.model_dump(**kwargs)  # noqa
    except Exception:
        # if using pydantic v1
        return self.dict(**kwargs)
def fields_set(self):
try:
@ -211,6 +211,38 @@ class UpdateUserRequest(GenerateRequestBase):
max_budget: Optional[float] = None
class NewTeamRequest(LiteLLMBase):
    # Request body accepted by the /team/new endpoint.

    # User defined team alias
    team_alias: Optional[str] = None
    # Team id; /team/new generates one when it is not passed
    team_id: Optional[str] = None
    # User IDs that will be owning the team
    admins: list = []
    # User IDs that will be members of the team
    members: list = []
    # Information stored for the team, e.g. {"team": "core-infra", "app": "app2"}
    metadata: Optional[dict] = None
class LiteLLM_TeamTable(NewTeamRequest):
    # Full team record: the request fields plus budget / limit settings.
    # /team/new fills the limit fields from the creating key before inserting
    # the record into the "team" table.

    max_budget: Optional[float] = None
    spend: Optional[float] = None
    models: list = []
    max_parallel_requests: Optional[int] = None
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    budget_duration: Optional[str] = None
    budget_reset_at: Optional[datetime] = None
class NewTeamResponse(LiteLLMBase):
    # Response model of the /team/new endpoint; every field is required.

    # Unique team id
    team_id: str
    # User IDs owning the team
    admins: list
    # User IDs that are members of the team
    members: list
    # Information stored for the team
    metadata: dict
    # Row timestamps
    created_at: datetime
    updated_at: datetime
class TeamRequest(LiteLLMBase):
    # Request carrying a list of team identifiers
    # NOTE(review): presumably team_id values - confirm against the caller.
    teams: List[str]
class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"

View file

@ -2799,6 +2799,7 @@ async def image_generation(
)
@router.post(
"/v1/moderations",
dependencies=[Depends(user_api_key_auth)],
@ -2973,7 +2974,8 @@ async def generate_key_fn(
Parameters:
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
- key_alias: Optional[str] - User defined key alias
- team_id: Optional[str] - The team id of the user
- team_id: Optional[str] - The team id of the key
- user_id: Optional[str] - The user id of the key
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
@ -3202,8 +3204,8 @@ async def info_key_fn_v2(
Example Curl:
```
curl -X GET "http://0.0.0.0:8000/key/info" \
-H "Authorization: Bearer sk-1234" \
-d {"keys": ["sk-1", "sk-2", "sk-3"]}
-H "Authorization: Bearer sk-1234" \
-d {"keys": ["sk-1", "sk-2", "sk-3"]}
```
"""
global prisma_client
@ -3313,6 +3315,9 @@ async def info_key_fn(
)
#### SPEND MANAGEMENT #####
@router.get(
"/spend/keys",
tags=["budget & spend Tracking"],
@ -3724,232 +3729,200 @@ async def user_auth(request: Request):
return "Email sent!"
@app.get("/sso/key/generate", tags=["experimental"])
async def google_login(request: Request):
@router.get(
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_info(
user_id: Optional[str] = fastapi.Query(
default=None, description="User ID in the request parameters"
)
):
"""
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
Example:
Use this to get user information. (user row + all user key info)
Example request
```
curl -X GET 'http://localhost:8000/user/info?user_id=krrish7%40berri.ai' \
--header 'Authorization: Bearer sk-1234'
```
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
ui_username = os.getenv("UI_USERNAME")
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
# Google SSO Auth
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
google_sso = GoogleSSO(
client_id=google_client_id,
client_secret=google_client_secret,
redirect_uri=redirect_url,
## GET USER ROW ##
if user_id is not None:
user_info = await prisma_client.get_data(user_id=user_id)
else:
user_info = None
## GET ALL KEYS ##
keys = await prisma_client.get_data(
user_id=user_id,
table_name="key",
query_type="find_all",
expires=datetime.now(),
)
verbose_proxy_logger.info(
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
)
if user_info is None:
## make sure we still return a total spend ##
spend = 0
for k in keys:
spend += getattr(k, "spend", 0)
user_info = {"spend": spend}
with google_sso:
return await google_sso.get_login_redirect()
# Microsoft SSO Auth
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
## REMOVE HASHED TOKEN INFO before returning ##
for key in keys:
try:
key = key.model_dump() # noqa
except:
# if using pydantic v1
key = key.dict()
key.pop("token", None)
return {"user_id": user_id, "user_info": user_info, "keys": keys}
except Exception as e:
if isinstance(e, HTTPException):
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with generic_sso:
return await generic_sso.get_login_redirect()
elif ui_username is not None:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_form, status_code=200)
else:
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_form, status_code=200)
@router.post(
"/login", include_in_schema=False
) # hidden since this is a helper for UI sso login
async def login(request: Request):
"/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_update(data: UpdateUserRequest):
"""
[TODO]: Use this to update user budget
"""
global prisma_client
try:
import multipart
except ImportError:
subprocess.run(["pip", "install", "python-multipart"])
global master_key
form = await request.form()
username = str(form.get("username"))
password = str(form.get("password"))
ui_username = os.getenv("UI_USERNAME", "admin")
ui_password = os.getenv("UI_PASSWORD", None)
if ui_password is None:
ui_password = str(master_key) if master_key is not None else None
data_json: dict = data.json()
# get the row from db
if prisma_client is None:
raise Exception("Not connected to DB!")
if ui_password is None:
raise ProxyException(
message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
type="auth_error",
param="UI_PASSWORD",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
non_default_values = {k: v for k, v in data_json.items() if v is not None}
response = await prisma_client.update_data(
user_id=data_json["user_id"],
data=non_default_values,
update_key_values=non_default_values,
)
if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
password, ui_password
):
user_role = "app_owner"
user_id = username
key_user_id = user_id
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
) or user_id == "admin":
# checks if user is admin
user_role = "app_admin"
key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")
# Admin is Authe'd in - generate key for the UI to access Proxy
if os.getenv("DATABASE_URL") is not None:
response = await generate_key_helper_fn(
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
return {"user_id": data_json["user_id"], **non_default_values}
# update based on remaining passed in values
except Exception as e:
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
response = {
"token": "sk-gm",
"user_id": "litellm-dashboard",
}
key = response["token"] # type: ignore
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/"
import jwt
jwt_token = jwt.encode(
{
"user_id": user_id,
"key": key,
"user_email": user_id,
"user_role": user_role,
},
"secret",
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
else:
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
message="Authentication Error, " + str(e),
type="auth_error",
param="invalid_credentials",
code=status.HTTP_401_UNAUTHORIZED,
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
#### TEAM MANAGEMENT ####
@router.post(
    "/team/new",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewTeamResponse,
)
async def new_team(
    data: NewTeamRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow users to create a new team. Apply user permissions to their team.

    Parameters:
    - team_alias: Optional[str] - User defined team alias
    - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it.
    - admins: list - A list of user IDs that will be owning the team
    - members: list - A list of user IDs that will be members of the team
    - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }

    Returns:
    - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id.
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # Generate an id when the caller did not choose one
    if data.team_id is None:
        data.team_id = str(uuid.uuid4())

    requested_fields = data.json()

    # The team takes its budget / limits from the key that creates it
    inherited_settings = {
        setting: getattr(user_api_key_dict, setting)
        for setting in (
            "max_budget",
            "models",
            "max_parallel_requests",
            "tpm_limit",
            "rpm_limit",
            "budget_duration",
            "budget_reset_at",
        )
    }
    team_record = LiteLLM_TeamTable(**requested_fields, **inherited_settings)

    return await prisma_client.insert_data(
        data=team_record.json(exclude_none=True), table_name="team"
    )
@router.post(
    "/team/update",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_team():
    """
    Update a team and its members.

    Stub endpoint: nothing is implemented yet, the handler returns None.
    """
@router.post(
    "/team/delete",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_team():
    """
    Delete a team and the team's keys.

    Stub endpoint: nothing is implemented yet, the handler returns None.
    """
@router.get(
    "/team/info",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def team_info(
    team_id: str = fastapi.Query(
        default=None, description="Team ID in the request parameters"
    )
):
    """
    Get info on a team and its related keys.

    Parameters:
    - team_id: str - Team ID passed as a query parameter

    Stub endpoint: nothing is implemented yet, the handler returns None.
    """
@app.get("/sso/callback", tags=["experimental"])
async def auth_callback(request: Request):
"""Verify login"""
@ -4125,117 +4098,6 @@ async def auth_callback(request: Request):
return RedirectResponse(url=litellm_dashboard_ui)
@router.get(
    "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_info(
    user_id: Optional[str] = fastapi.Query(
        default=None, description="User ID in the request parameters"
    )
):
    """
    Use this to get user information. (user row + all user key info)

    Parameters:
    - user_id: Optional[str] - User ID passed as a query parameter

    Returns:
    - {"user_id": ..., "user_info": ..., "keys": [...]} - "user_info" is the user row,
      or {"spend": <sum of key spend>} when no row is found. "keys" are returned as
      dicts with the hashed "token" field removed.

    Raises:
    - ProxyException - if the database is not connected or the lookup fails

    Example request
    ```
    curl -X GET 'http://localhost:8000/user/info?user_id=krrish7%40berri.ai' \
    --header 'Authorization: Bearer sk-1234'
    ```
    """
    global prisma_client
    try:
        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )

        ## GET USER ROW ##
        user_row = None
        if user_id is not None:
            user_row = await prisma_client.get_data(user_id=user_id)

        ## GET ALL KEYS ##
        keys = await prisma_client.get_data(
            user_id=user_id,
            table_name="key",
            query_type="find_all",
            expires=datetime.now(),
        )

        if user_row is None:
            ## make sure we still return a total spend ##
            # `or 0` guards against a key whose spend is stored as None
            total_spend = sum((getattr(k, "spend", 0) or 0) for k in keys)
            user_row = {"spend": total_spend}

        ## REMOVE HASHED TOKEN INFO before returning ##
        # Collect the sanitized dicts: popping from a temporary copy alone
        # would leave the hashed token in the objects that get returned.
        sanitized_keys = []
        for key in keys:
            if isinstance(key, dict):
                key_data = dict(key)
            else:
                try:
                    key_data = key.model_dump()  # noqa
                except Exception:
                    # if using pydantic v1
                    key_data = key.dict()
            key_data.pop("token", None)
            sanitized_keys.append(key_data)

        return {"user_id": user_id, "user_info": user_row, "keys": sanitized_keys}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
@router.post(
    "/user/update",
    tags=["user management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def user_update(data: UpdateUserRequest):
    """
    [TODO]: Use this to update user budget

    Writes every non-None field of the request to the user row identified by
    `user_id` and returns those fields together with the `user_id`.
    """
    global prisma_client
    try:
        payload: dict = data.json()

        if prisma_client is None:
            raise Exception("Not connected to DB!")

        # only the values that were actually passed in get written
        fields_to_update = {}
        for field_name, field_value in payload.items():
            if field_value is not None:
                fields_to_update[field_name] = field_value

        target_user_id = payload["user_id"]
        await prisma_client.update_data(
            user_id=target_user_id,
            data=fields_to_update,
            update_key_values=fields_to_update,
        )
        return {"user_id": payload["user_id"], **fields_to_update}
    except Exception as e:
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        if isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
#### MODEL MANAGEMENT ####
@ -4587,6 +4449,341 @@ async def retrieve_server_log(request: Request):
return FileResponse(filepath)
#### LOGIN ENDPOINTS ####
@app.get("/sso/key/generate", tags=["experimental"])
async def google_login(request: Request):
    """
    Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env

    PROXY_BASE_URL should be your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"

    The SSO provider is picked from the environment, in this order:
    GOOGLE_CLIENT_ID, MICROSOFT_CLIENT_ID, GENERIC_CLIENT_ID. When none of them
    is set, the username/password login form is returned instead.

    Returns:
    - the provider's login redirect, or an HTMLResponse with the login form

    Raises:
    - ProxyException - if a provider is selected but one of its required
      settings (client secret / endpoints) is missing
    """
    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    # must be read here - the generic branch below depends on it
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    # get url from request
    redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
    if redirect_url.endswith("/"):
        redirect_url += "sso/callback"
    else:
        redirect_url += "/sso/callback"

    # Google SSO Auth
    if google_client_id is not None:
        from fastapi_sso.sso.google import GoogleSSO

        google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
        if google_client_secret is None:
            raise ProxyException(
                message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
                type="auth_error",
                param="GOOGLE_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        google_sso = GoogleSSO(
            client_id=google_client_id,
            client_secret=google_client_secret,
            redirect_uri=redirect_url,
        )
        verbose_proxy_logger.info(
            f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
        )
        with google_sso:
            return await google_sso.get_login_redirect()

    # Microsoft SSO Auth
    elif microsoft_client_id is not None:
        from fastapi_sso.sso.microsoft import MicrosoftSSO

        microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
        if microsoft_client_secret is None:
            raise ProxyException(
                message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
                type="auth_error",
                param="MICROSOFT_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        microsoft_sso = MicrosoftSSO(
            client_id=microsoft_client_id,
            client_secret=microsoft_client_secret,
            tenant=microsoft_tenant,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )
        with microsoft_sso:
            return await microsoft_sso.get_login_redirect()

    # Generic (OIDC) SSO Auth
    elif generic_client_id is not None:
        from fastapi_sso.sso.generic import create_provider, DiscoveryDocument

        # every one of these settings is required; fail on the first missing one
        generic_settings = {}
        for env_name in (
            "GENERIC_CLIENT_SECRET",
            "GENERIC_AUTHORIZATION_ENDPOINT",
            "GENERIC_TOKEN_ENDPOINT",
            "GENERIC_USERINFO_ENDPOINT",
        ):
            env_value = os.getenv(env_name, None)
            if env_value is None:
                raise ProxyException(
                    message=f"{env_name} not set. Set it in .env file",
                    type="auth_error",
                    param=env_name,
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            generic_settings[env_name] = env_value

        generic_client_secret = generic_settings["GENERIC_CLIENT_SECRET"]
        generic_authorization_endpoint = generic_settings[
            "GENERIC_AUTHORIZATION_ENDPOINT"
        ]
        generic_token_endpoint = generic_settings["GENERIC_TOKEN_ENDPOINT"]
        generic_userinfo_endpoint = generic_settings["GENERIC_USERINFO_ENDPOINT"]

        verbose_proxy_logger.debug(
            f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
        )
        verbose_proxy_logger.debug(
            f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
        )

        discovery = DiscoveryDocument(
            authorization_endpoint=generic_authorization_endpoint,
            token_endpoint=generic_token_endpoint,
            userinfo_endpoint=generic_userinfo_endpoint,
        )
        SSOProvider = create_provider(name="oidc", discovery_document=discovery)
        generic_sso = SSOProvider(
            client_id=generic_client_id,
            client_secret=generic_client_secret,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )
        with generic_sso:
            return await generic_sso.get_login_redirect()

    # No SSO provider configured - use the UI credentials set in .env.
    # The login form is served whether or not UI_USERNAME is set.
    from fastapi.responses import HTMLResponse

    return HTMLResponse(content=html_form, status_code=200)
@router.post(
    "/login", include_in_schema=False
)  # hidden since this is a helper for UI sso login
async def login(request: Request):
    """
    Validate the username/password login form and redirect to the dashboard UI.

    The submitted credentials are compared against UI_USERNAME (default "admin")
    and UI_PASSWORD (falling back to the proxy master key). On success a key is
    generated for the UI and passed to the dashboard inside a JWT.

    Returns:
    - RedirectResponse (303) to the dashboard UI with `userID` and `token` set

    Raises:
    - ProxyException (500) - if neither UI_PASSWORD nor a master key is set
    - ProxyException (401) - if the submitted credentials do not match
    """
    try:
        import multipart
    except ImportError:
        # NOTE(review): installs a package at request time; consider making
        # python-multipart a declared dependency instead.
        subprocess.run(["pip", "install", "python-multipart"])
    global master_key
    form = await request.form()
    username = str(form.get("username"))
    password = str(form.get("password"))
    ui_username = os.getenv("UI_USERNAME", "admin")
    ui_password = os.getenv("UI_PASSWORD", None)
    if ui_password is None:
        ui_password = str(master_key) if master_key is not None else None

    if ui_password is None:
        raise ProxyException(
            message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
            type="auth_error",
            param="UI_PASSWORD",
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # compare_digest only accepts ASCII-only str; compare bytes so that
    # non-ASCII input is rejected as a mismatch instead of raising TypeError
    username_matches = secrets.compare_digest(
        username.encode("utf-8"), ui_username.encode("utf-8")
    )
    password_matches = secrets.compare_digest(
        password.encode("utf-8"), ui_password.encode("utf-8")
    )
    if not (username_matches and password_matches):
        # never echo the submitted password back to the client
        raise ProxyException(
            message=f"Invalid credentials used to access UI. Passed in username: {username}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
            type="auth_error",
            param="invalid_credentials",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    user_role = "app_owner"
    user_id = username
    key_user_id = user_id
    if (
        os.getenv("PROXY_ADMIN_ID", None) is not None
        and os.environ["PROXY_ADMIN_ID"] == user_id
    ) or user_id == "admin":
        # checks if user is admin
        user_role = "app_admin"
        key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")

    # Admin is Authe'd in - generate key for the UI to access Proxy
    if os.getenv("DATABASE_URL") is not None:
        response = await generate_key_helper_fn(
            **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"}  # type: ignore
        )
    else:
        response = {
            "token": "sk-gm",
            "user_id": "litellm-dashboard",
        }

    key = response["token"]  # type: ignore

    # PROXY_BASE_URL may be set without a trailing slash
    proxy_base_url = os.getenv("PROXY_BASE_URL", "/")
    if not proxy_base_url.endswith("/"):
        proxy_base_url += "/"
    litellm_dashboard_ui = proxy_base_url + "ui/"

    import jwt

    # NOTE(review): the signing secret is hard-coded, so this token offers no
    # integrity protection - confirm what the UI relies on before changing it.
    jwt_token = jwt.encode(
        {
            "user_id": user_id,
            "key": key,
            "user_email": user_id,
            "user_role": user_role,
        },
        "secret",
        algorithm="HS256",
    )
    litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token

    # if a user has logged in they should be allowed to create keys - this ensures that it's set to True
    general_settings["allow_user_auth"] = True
    return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
@app.get("/sso/callback", tags=["experimental"])
async def auth_callback(request: Request):
    """
    Verify login

    Completes the SSO flow for the configured provider (Google when
    GOOGLE_CLIENT_ID is set, otherwise Microsoft when MICROSOFT_CLIENT_ID is
    set), generates a key for the UI and redirects to the dashboard with the
    key wrapped in a JWT.

    Returns:
    - RedirectResponse to "/ui/" with `userID` and `token` query parameters

    Raises:
    - ProxyException (500) - if a required provider setting is missing, or if
      neither Google nor Microsoft SSO is configured
    """
    global general_settings
    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)

    # get url from request
    redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
    if redirect_url.endswith("/"):
        redirect_url += "sso/callback"
    else:
        redirect_url += "/sso/callback"

    if google_client_id is not None:
        from fastapi_sso.sso.google import GoogleSSO

        google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
        if google_client_secret is None:
            raise ProxyException(
                message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
                type="auth_error",
                param="GOOGLE_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        google_sso = GoogleSSO(
            client_id=google_client_id,
            redirect_uri=redirect_url,
            client_secret=google_client_secret,
        )
        result = await google_sso.verify_and_process(request)

    elif microsoft_client_id is not None:
        from fastapi_sso.sso.microsoft import MicrosoftSSO

        microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
        if microsoft_client_secret is None:
            raise ProxyException(
                message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
                type="auth_error",
                param="MICROSOFT_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        if microsoft_tenant is None:
            raise ProxyException(
                message="MICROSOFT_TENANT not set. Set it in .env file",
                type="auth_error",
                param="MICROSOFT_TENANT",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        microsoft_sso = MicrosoftSSO(
            client_id=microsoft_client_id,
            client_secret=microsoft_client_secret,
            tenant=microsoft_tenant,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )
        result = await microsoft_sso.verify_and_process(request)

    else:
        # Without this branch `result` would be unbound below.
        # NOTE(review): the generic OIDC provider offered by /sso/key/generate
        # has no callback handling here yet.
        raise ProxyException(
            message="No supported SSO provider configured for /sso/callback. Set GOOGLE_CLIENT_ID or MICROSOFT_CLIENT_ID in .env file",
            type="auth_error",
            param="SSO_PROVIDER",
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # User is Authe'd in - generate key for the UI to access Proxy
    user_email = getattr(result, "email", None)
    user_id = getattr(result, "id", None)
    if user_id is None:
        # either name may be present but None - treat it as empty
        first_name = getattr(result, "first_name", "") or ""
        last_name = getattr(result, "last_name", "") or ""
        user_id = first_name + last_name

    response = await generate_key_helper_fn(
        **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email}  # type: ignore
    )

    key = response["token"]  # type: ignore
    user_id = response["user_id"]  # type: ignore
    litellm_dashboard_ui = "/ui/"

    user_role = "app_owner"
    if (
        os.getenv("PROXY_ADMIN_ID", None) is not None
        and os.environ["PROXY_ADMIN_ID"] == user_id
    ):
        # checks if user is admin
        user_role = "app_admin"

    import jwt

    # NOTE(review): the signing secret is hard-coded, so this token offers no
    # integrity protection - confirm what the UI relies on before changing it.
    jwt_token = jwt.encode(
        {
            "user_id": user_id,
            "key": key,
            "user_email": user_email,
            "user_role": user_role,
        },
        "secret",
        algorithm="HS256",
    )
    litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token

    # if a user has logged in they should be allowed to create keys - this ensures that it's set to True
    general_settings["allow_user_auth"] = True
    return RedirectResponse(url=litellm_dashboard_ui)
#### BASIC ENDPOINTS ####
@router.post(
"/config/update",

View file

@ -8,6 +8,25 @@ generator client {
provider = "prisma-client-py"
}
// Assign prod keys to groups, not individuals
model LiteLLM_TeamTable {
// Unique identifier of the team
team_id String @unique
// Optional user defined alias
team_alias String?
// User IDs owning the team
admins String[]
// User IDs that are members of the team
members String[]
// Information stored for the team; defaults to an empty JSON object
metadata Json @default("{}")
// Optional budget for the team - NOTE(review): units and enforcement are not shown here
max_budget Float?
// Spend recorded for the team; starts at 0.0
spend Float @default(0.0)
// Model names for the team - presumably the models it may call; confirm
models String[]
// Optional limits
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
// Budget period and its next reset time - NOTE(review): duration format not shown here
budget_duration String?
budget_reset_at DateTime?
// Row timestamps: both default to now(); updated_at is maintained via @updatedAt
created_at DateTime @default(now()) @map("created_at")
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
// Track spend, rate limit, budget Users
model LiteLLM_UserTable {
user_id String @unique

View file

@ -511,8 +511,9 @@ class PrismaClient:
token: Optional[Union[str, list]] = None,
user_id: Optional[str] = None,
user_id_list: Optional[list] = None,
team_id: Optional[str] = None,
key_val: Optional[dict] = None,
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique",
expires: Optional[datetime] = None,
reset_at: Optional[datetime] = None,
@ -547,6 +548,14 @@ class PrismaClient:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
elif query_type == "find_all" and team_id is not None:
response = await self.db.litellm_verificationtoken.find_many(
where={"team_id": team_id}
)
if response is not None and len(response) > 0:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
elif (
query_type == "find_all"
and expires is not None
@ -659,7 +668,23 @@ class PrismaClient:
order={"startTime": "desc"},
)
return response
elif table_name == "team":
if query_type == "find_unique":
response = await self.db.litellm_teamtable.find_unique(
where={"team_id": team_id} # type: ignore
)
if query_type == "find_all" and team_id is not None:
user_id_values = str(tuple(team_id))
sql_query = f"""
SELECT *
FROM "LiteLLM_TeamTable"
WHERE "team_id" IN {team_id}
"""
# Execute the raw query
# The asterisk before `team_id` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response
except Exception as e:
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
import traceback
@ -679,7 +704,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(
self, data: dict, table_name: Literal["user", "key", "config", "spend"]
self, data: dict, table_name: Literal["user", "key", "config", "spend", "team"]
):
"""
Add a key to the database. If it already exists, do nothing.
@ -715,6 +740,17 @@ class PrismaClient:
)
verbose_proxy_logger.info(f"Data Inserted into User Table")
return new_user_row
elif table_name == "team":
db_data = self.jsonify_object(data=data)
new_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": data["team_id"]},
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
verbose_proxy_logger.info(f"Data Inserted into Team Table")
return new_team_row
elif table_name == "config":
"""
For each param,

View file

@ -8,6 +8,25 @@ generator client {
provider = "prisma-client-py"
}
// Assign prod keys to groups, not individuals
model LiteLLM_TeamTable {
// Unique identifier of the team
team_id String @unique
// Optional user defined alias
team_alias String?
// User IDs owning the team
admins String[]
// User IDs that are members of the team
members String[]
// Information stored for the team; defaults to an empty JSON object
metadata Json @default("{}")
// Optional budget for the team - NOTE(review): units and enforcement are not shown here
max_budget Float?
// Spend recorded for the team; starts at 0.0
spend Float @default(0.0)
// Model names for the team - presumably the models it may call; confirm
models String[]
// Optional limits
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
// Budget period and its next reset time - NOTE(review): duration format not shown here
budget_duration String?
budget_reset_at DateTime?
// Row timestamps: both default to now(); updated_at is maintained via @updatedAt
created_at DateTime @default(now()) @map("created_at")
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
// Track spend, rate limit, budget Users
model LiteLLM_UserTable {
user_id String @unique