feat(proxy_server.py): working /team/new endpoint for creating a new team

This commit is contained in:
Krrish Dholakia 2024-02-14 17:20:41 -08:00
parent d589061a9c
commit c094db7160
5 changed files with 383 additions and 276 deletions

View file

@ -21,10 +21,10 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs):
try:
return self.model_dump() # noqa
return self.model_dump(**kwargs) # noqa
except Exception as e:
# if using pydantic v1
return self.dict()
return self.dict(**kwargs)
def fields_set(self):
try:
@ -211,6 +211,34 @@ class UpdateUserRequest(GenerateRequestBase):
max_budget: Optional[float] = None
class NewTeamRequest(LiteLLMBase):
team_alias: Optional[str] = None
team_id: Optional[str] = None
admins: list = []
members: list = []
metadata: Optional[dict] = None
class LiteLLM_TeamTable(NewTeamRequest):
max_budget: Optional[float] = None
spend: Optional[float] = None
models: list = []
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
class NewTeamResponse(LiteLLMBase):
team_id: str
admins: list
members: list
metadata: dict
created_at: datetime
updated_at: datetime
class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"

View file

@ -3573,271 +3573,6 @@ async def user_auth(request: Request):
return "Email sent!"
@app.get("/sso/key/generate", tags=["experimental"])
async def google_login(request: Request):
"""
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
Example:
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
ui_username = os.getenv("UI_USERNAME")
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
# Google SSO Auth
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
client_secret=google_client_secret,
redirect_uri=redirect_url,
)
verbose_proxy_logger.info(
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
)
with google_sso:
return await google_sso.get_login_redirect()
# Microsoft SSO Auth
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif ui_username is not None:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_form, status_code=200)
else:
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_form, status_code=200)
@router.post(
"/login", include_in_schema=False
) # hidden since this is a helper for UI sso login
async def login(request: Request):
try:
import multipart
except ImportError:
subprocess.run(["pip", "install", "python-multipart"])
global master_key
form = await request.form()
username = str(form.get("username"))
password = str(form.get("password"))
ui_username = os.getenv("UI_USERNAME", "admin")
ui_password = os.getenv("UI_PASSWORD", None)
if ui_password is None:
ui_password = str(master_key) if master_key is not None else None
if ui_password is None:
raise ProxyException(
message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
type="auth_error",
param="UI_PASSWORD",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
password, ui_password
):
user_role = "app_owner"
user_id = username
key_user_id = user_id
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
) or user_id == "admin":
# checks if user is admin
user_role = "app_admin"
key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")
# Admin is Authe'd in - generate key for the UI to access Proxy
if os.getenv("DATABASE_URL") is not None:
response = await generate_key_helper_fn(
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
)
else:
response = {
"token": "sk-gm",
"user_id": "litellm-dashboard",
}
key = response["token"] # type: ignore
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/"
import jwt
jwt_token = jwt.encode(
{
"user_id": user_id,
"key": key,
"user_email": user_id,
"user_role": user_role,
},
"secret",
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
else:
raise ProxyException(
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
type="auth_error",
param="invalid_credentials",
code=status.HTTP_401_UNAUTHORIZED,
)
@app.get("/sso/callback", tags=["experimental"])
async def auth_callback(request: Request):
"""Verify login"""
global general_settings
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
redirect_uri=redirect_url,
client_secret=google_client_secret,
)
result = await google_sso.verify_and_process(request)
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if microsoft_tenant is None:
raise ProxyException(
message="MICROSOFT_TENANT not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_TENANT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
result = await microsoft_sso.verify_and_process(request)
# User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None)
user_id = getattr(result, "id", None)
if user_id is None:
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
response = await generate_key_helper_fn(
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
litellm_dashboard_ui = "/ui/"
user_role = "app_owner"
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
):
# checks if user is admin
user_role = "app_admin"
import jwt
jwt_token = jwt.encode(
{
"user_id": user_id,
"key": key,
"user_email": user_email,
"user_role": user_role,
},
"secret",
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui)
@router.get(
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
@ -3953,24 +3688,51 @@ async def user_update(data: UpdateUserRequest):
@router.post(
"/team/new", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
"/team/new",
tags=["team management"],
dependencies=[Depends(user_api_key_auth)],
response_model=NewTeamResponse,
)
async def new_team():
async def new_team(
data: NewTeamRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create a new team
Allow users to create a new team. Apply user permissions to their team.
Parameters:
- team_alias: Optional[str] - User defined team alias
- team_id: Optional[str] - The team id of the user. If none passed, we'll generate it.
- team_admins: list - A list of user IDs that will be owning the team
- admins: list - A list of user IDs that will be owning the team
- members: list - A list of user IDs that will be members of the team
- metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
Returns:
- key: (str) The generated api key
- expires: (datetime) Datetime object for when key expires.
- team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id.
"""
pass
global prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if data.team_id is None:
data.team_id = str(uuid.uuid4())
complete_team_data = LiteLLM_TeamTable(
**data.json(),
max_budget=user_api_key_dict.max_budget,
models=user_api_key_dict.models,
max_parallel_requests=user_api_key_dict.max_parallel_requests,
tpm_limit=user_api_key_dict.tpm_limit,
rpm_limit=user_api_key_dict.rpm_limit,
budget_duration=user_api_key_dict.budget_duration,
budget_reset_at=user_api_key_dict.budget_reset_at,
)
team_row = await prisma_client.insert_data(
data=complete_team_data.json(exclude_none=True), table_name="team"
)
return team_row
@router.post(
@ -4354,6 +4116,274 @@ async def retrieve_server_log(request: Request):
return FileResponse(filepath)
#### LOGIN ENDPOINTS ####
@app.get("/sso/key/generate", tags=["experimental"])
async def google_login(request: Request):
"""
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
Example:
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
ui_username = os.getenv("UI_USERNAME")
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
# Google SSO Auth
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
client_secret=google_client_secret,
redirect_uri=redirect_url,
)
verbose_proxy_logger.info(
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
)
with google_sso:
return await google_sso.get_login_redirect()
# Microsoft SSO Auth
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif ui_username is not None:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_form, status_code=200)
else:
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_form, status_code=200)
@router.post(
"/login", include_in_schema=False
) # hidden since this is a helper for UI sso login
async def login(request: Request):
try:
import multipart
except ImportError:
subprocess.run(["pip", "install", "python-multipart"])
global master_key
form = await request.form()
username = str(form.get("username"))
password = str(form.get("password"))
ui_username = os.getenv("UI_USERNAME", "admin")
ui_password = os.getenv("UI_PASSWORD", None)
if ui_password is None:
ui_password = str(master_key) if master_key is not None else None
if ui_password is None:
raise ProxyException(
message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
type="auth_error",
param="UI_PASSWORD",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
password, ui_password
):
user_role = "app_owner"
user_id = username
key_user_id = user_id
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
) or user_id == "admin":
# checks if user is admin
user_role = "app_admin"
key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")
# Admin is Authe'd in - generate key for the UI to access Proxy
if os.getenv("DATABASE_URL") is not None:
response = await generate_key_helper_fn(
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
)
else:
response = {
"token": "sk-gm",
"user_id": "litellm-dashboard",
}
key = response["token"] # type: ignore
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/"
import jwt
jwt_token = jwt.encode(
{
"user_id": user_id,
"key": key,
"user_email": user_id,
"user_role": user_role,
},
"secret",
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
else:
raise ProxyException(
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
type="auth_error",
param="invalid_credentials",
code=status.HTTP_401_UNAUTHORIZED,
)
@app.get("/sso/callback", tags=["experimental"])
async def auth_callback(request: Request):
"""Verify login"""
global general_settings
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
redirect_uri=redirect_url,
client_secret=google_client_secret,
)
result = await google_sso.verify_and_process(request)
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if microsoft_tenant is None:
raise ProxyException(
message="MICROSOFT_TENANT not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_TENANT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
result = await microsoft_sso.verify_and_process(request)
# User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None)
user_id = getattr(result, "id", None)
if user_id is None:
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
response = await generate_key_helper_fn(
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
litellm_dashboard_ui = "/ui/"
user_role = "app_owner"
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
):
# checks if user is admin
user_role = "app_admin"
import jwt
jwt_token = jwt.encode(
{
"user_id": user_id,
"key": key,
"user_email": user_email,
"user_role": user_role,
},
"secret",
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui)
#### BASIC ENDPOINTS ####
@router.post(
"/config/update",

View file

@ -8,6 +8,25 @@ generator client {
provider = "prisma-client-py"
}
// Assign prod keys to groups, not individuals
model LiteLLM_TeamTable {
team_id String @unique
team_alias String?
admins String[]
members String[]
metadata Json @default("{}")
max_budget Float?
spend Float @default(0.0)
models String[]
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
budget_duration String?
budget_reset_at DateTime?
created_at DateTime @default(now()) @map("created_at")
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
// Track spend, rate limit, budget Users
model LiteLLM_UserTable {
user_id String @unique

View file

@ -677,7 +677,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(
self, data: dict, table_name: Literal["user", "key", "config", "spend"]
self, data: dict, table_name: Literal["user", "key", "config", "spend", "team"]
):
"""
Add a key to the database. If it already exists, do nothing.
@ -713,6 +713,17 @@ class PrismaClient:
)
verbose_proxy_logger.info(f"Data Inserted into User Table")
return new_user_row
elif table_name == "team":
db_data = self.jsonify_object(data=data)
new_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": data["team_id"]},
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
verbose_proxy_logger.info(f"Data Inserted into Team Table")
return new_team_row
elif table_name == "config":
"""
For each param,

View file

@ -1,13 +1,32 @@
datasource client {
provider = "postgresql"
url = env("DATABASE_URL")
directUrl = env("DIRECT_URL")
directUrl = env("DIRECT_URL")?
}
generator client {
provider = "prisma-client-py"
}
// Assign prod keys to groups, not individuals
model LiteLLM_TeamTable {
team_id String @unique
team_alias String?
admins String[]
members String[]
metadata Json @default("{}")
max_budget Float?
spend Float @default(0.0)
models String[]
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
budget_duration String?
budget_reset_at DateTime?
created_at DateTime @default(now()) @map("created_at")
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
// Track spend, rate limit, budget Users
model LiteLLM_UserTable {
user_id String @unique