forked from phoenix/litellm-mirror
Merge pull request #1979 from BerriAI/litellm_team_settings
refactor(proxy_server.py): initial stubbed endpoints for team management
This commit is contained in:
commit
15bd6266b2
6 changed files with 625 additions and 321 deletions
|
@ -10,6 +10,7 @@
|
||||||
import os, openai, sys, json, inspect, uuid, datetime, threading
|
import os, openai, sys, json, inspect, uuid, datetime, threading
|
||||||
from typing import Any, Literal, Union
|
from typing import Any, Literal, Union
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import dotenv, traceback, random, asyncio, time, contextvars
|
import dotenv, traceback, random, asyncio, time, contextvars
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import httpx
|
import httpx
|
||||||
|
|
|
@ -21,10 +21,10 @@ class LiteLLMBase(BaseModel):
|
||||||
|
|
||||||
def json(self, **kwargs):
|
def json(self, **kwargs):
|
||||||
try:
|
try:
|
||||||
return self.model_dump() # noqa
|
return self.model_dump(**kwargs) # noqa
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# if using pydantic v1
|
# if using pydantic v1
|
||||||
return self.dict()
|
return self.dict(**kwargs)
|
||||||
|
|
||||||
def fields_set(self):
|
def fields_set(self):
|
||||||
try:
|
try:
|
||||||
|
@ -211,6 +211,38 @@ class UpdateUserRequest(GenerateRequestBase):
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class NewTeamRequest(LiteLLMBase):
    """Request body for creating a team.

    Every field has a default, so an empty payload is accepted.
    """

    # User-defined alias for the team.
    team_alias: Optional[str] = None
    # Team identifier; may be omitted by the caller.
    team_id: Optional[str] = None
    # User IDs that own the team.
    admins: list = []
    # User IDs that are members of the team.
    members: list = []
    # Free-form information stored alongside the team.
    metadata: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_TeamTable(NewTeamRequest):
    """Team record: the fields of ``NewTeamRequest`` plus budget and rate-limit settings.

    All additional fields are optional.
    """

    # Budget settings.
    max_budget: Optional[float] = None
    spend: Optional[float] = None
    # Model names associated with the team.
    models: list = []
    # Throughput limits.
    max_parallel_requests: Optional[int] = None
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    # Budget window and its next reset time.
    budget_duration: Optional[str] = None
    budget_reset_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class NewTeamResponse(LiteLLMBase):
    """Response model declared for the ``/team/new`` endpoint.

    All fields are required (none has a default).
    """

    # Unique team id.
    team_id: str
    # User IDs owning / belonging to the team.
    admins: list
    members: list
    # Metadata stored for the team.
    metadata: dict
    # Row timestamps.
    created_at: datetime
    updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class TeamRequest(LiteLLMBase):
    """Request body carrying a list of teams."""

    # NOTE(review): presumably team ids - confirm against the endpoint that uses this model.
    teams: List[str]
|
||||||
|
|
||||||
|
|
||||||
class KeyManagementSystem(enum.Enum):
|
class KeyManagementSystem(enum.Enum):
|
||||||
GOOGLE_KMS = "google_kms"
|
GOOGLE_KMS = "google_kms"
|
||||||
AZURE_KEY_VAULT = "azure_key_vault"
|
AZURE_KEY_VAULT = "azure_key_vault"
|
||||||
|
|
|
@ -2799,6 +2799,7 @@ async def image_generation(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/moderations",
|
"/v1/moderations",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
@ -2973,7 +2974,8 @@ async def generate_key_fn(
|
||||||
Parameters:
|
Parameters:
|
||||||
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||||
- key_alias: Optional[str] - User defined key alias
|
- key_alias: Optional[str] - User defined key alias
|
||||||
- team_id: Optional[str] - The team id of the user
|
- team_id: Optional[str] - The team id of the key
|
||||||
|
- user_id: Optional[str] - The user id of the key
|
||||||
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
|
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
|
||||||
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
|
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
|
||||||
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
|
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
|
||||||
|
@ -3202,8 +3204,8 @@ async def info_key_fn_v2(
|
||||||
Example Curl:
|
Example Curl:
|
||||||
```
|
```
|
||||||
curl -X GET "http://0.0.0.0:8000/key/info" \
|
curl -X GET "http://0.0.0.0:8000/key/info" \
|
||||||
-H "Authorization: Bearer sk-1234" \
|
-H "Authorization: Bearer sk-1234" \
|
||||||
-d {"keys": ["sk-1", "sk-2", "sk-3"]}
|
-d {"keys": ["sk-1", "sk-2", "sk-3"]}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
global prisma_client
|
global prisma_client
|
||||||
|
@ -3313,6 +3315,9 @@ async def info_key_fn(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
#### SPEND MANAGEMENT #####
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/spend/keys",
|
"/spend/keys",
|
||||||
tags=["budget & spend Tracking"],
|
tags=["budget & spend Tracking"],
|
||||||
|
@ -3724,232 +3729,200 @@ async def user_auth(request: Request):
|
||||||
return "Email sent!"
|
return "Email sent!"
|
||||||
|
|
||||||
|
|
||||||
@app.get("/sso/key/generate", tags=["experimental"])
|
@router.get(
|
||||||
async def google_login(request: Request):
|
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
|
||||||
|
)
|
||||||
|
async def user_info(
|
||||||
|
user_id: Optional[str] = fastapi.Query(
|
||||||
|
default=None, description="User ID in the request parameters"
|
||||||
|
)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
Use this to get user information. (user row + all user key info)
|
||||||
|
|
||||||
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
|
||||||
Example:
|
|
||||||
|
|
||||||
|
Example request
|
||||||
|
```
|
||||||
|
curl -X GET 'http://localhost:8000/user/info?user_id=krrish7%40berri.ai' \
|
||||||
|
--header 'Authorization: Bearer sk-1234'
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
global prisma_client
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
try:
|
||||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
if prisma_client is None:
|
||||||
|
raise Exception(
|
||||||
# get url from request
|
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
|
||||||
|
|
||||||
ui_username = os.getenv("UI_USERNAME")
|
|
||||||
if redirect_url.endswith("/"):
|
|
||||||
redirect_url += "sso/callback"
|
|
||||||
else:
|
|
||||||
redirect_url += "/sso/callback"
|
|
||||||
# Google SSO Auth
|
|
||||||
if google_client_id is not None:
|
|
||||||
from fastapi_sso.sso.google import GoogleSSO
|
|
||||||
|
|
||||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
|
||||||
if google_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GOOGLE_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
)
|
||||||
|
## GET USER ROW ##
|
||||||
google_sso = GoogleSSO(
|
if user_id is not None:
|
||||||
client_id=google_client_id,
|
user_info = await prisma_client.get_data(user_id=user_id)
|
||||||
client_secret=google_client_secret,
|
else:
|
||||||
redirect_uri=redirect_url,
|
user_info = None
|
||||||
|
## GET ALL KEYS ##
|
||||||
|
keys = await prisma_client.get_data(
|
||||||
|
user_id=user_id,
|
||||||
|
table_name="key",
|
||||||
|
query_type="find_all",
|
||||||
|
expires=datetime.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.info(
|
if user_info is None:
|
||||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
## make sure we still return a total spend ##
|
||||||
)
|
spend = 0
|
||||||
|
for k in keys:
|
||||||
|
spend += getattr(k, "spend", 0)
|
||||||
|
user_info = {"spend": spend}
|
||||||
|
|
||||||
with google_sso:
|
## REMOVE HASHED TOKEN INFO before returning ##
|
||||||
return await google_sso.get_login_redirect()
|
for key in keys:
|
||||||
|
try:
|
||||||
# Microsoft SSO Auth
|
key = key.model_dump() # noqa
|
||||||
elif microsoft_client_id is not None:
|
except:
|
||||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
# if using pydantic v1
|
||||||
|
key = key.dict()
|
||||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
key.pop("token", None)
|
||||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
return {"user_id": user_id, "user_info": user_info, "keys": keys}
|
||||||
if microsoft_client_secret is None:
|
except Exception as e:
|
||||||
|
if isinstance(e, HTTPException):
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||||
type="auth_error",
|
type="auth_error",
|
||||||
param="MICROSOFT_CLIENT_SECRET",
|
param=getattr(e, "param", "None"),
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||||
)
|
)
|
||||||
|
elif isinstance(e, ProxyException):
|
||||||
microsoft_sso = MicrosoftSSO(
|
raise e
|
||||||
client_id=microsoft_client_id,
|
raise ProxyException(
|
||||||
client_secret=microsoft_client_secret,
|
message="Authentication Error, " + str(e),
|
||||||
tenant=microsoft_tenant,
|
type="auth_error",
|
||||||
redirect_uri=redirect_url,
|
param=getattr(e, "param", "None"),
|
||||||
allow_insecure_http=True,
|
code=status.HTTP_400_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
with microsoft_sso:
|
|
||||||
return await microsoft_sso.get_login_redirect()
|
|
||||||
elif generic_client_id is not None:
|
|
||||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
|
||||||
generic_authorization_endpoint = os.getenv(
|
|
||||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
||||||
)
|
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
||||||
if generic_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_authorization_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_token_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_TOKEN_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_userinfo_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
||||||
type="auth_error",
|
|
||||||
param="GENERIC_USERINFO_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
discovery = DiscoveryDocument(
|
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
|
||||||
token_endpoint=generic_token_endpoint,
|
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
|
||||||
generic_sso = SSOProvider(
|
|
||||||
client_id=generic_client_id,
|
|
||||||
client_secret=generic_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
with generic_sso:
|
|
||||||
return await generic_sso.get_login_redirect()
|
|
||||||
|
|
||||||
elif ui_username is not None:
|
|
||||||
# No Google, Microsoft SSO
|
|
||||||
# Use UI Credentials set in .env
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
|
|
||||||
return HTMLResponse(content=html_form, status_code=200)
|
|
||||||
else:
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
|
|
||||||
return HTMLResponse(content=html_form, status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/login", include_in_schema=False
|
"/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
|
||||||
) # hidden since this is a helper for UI sso login
|
)
|
||||||
async def login(request: Request):
|
async def user_update(data: UpdateUserRequest):
|
||||||
|
"""
|
||||||
|
[TODO]: Use this to update user budget
|
||||||
|
"""
|
||||||
|
global prisma_client
|
||||||
try:
|
try:
|
||||||
import multipart
|
data_json: dict = data.json()
|
||||||
except ImportError:
|
# get the row from db
|
||||||
subprocess.run(["pip", "install", "python-multipart"])
|
if prisma_client is None:
|
||||||
global master_key
|
raise Exception("Not connected to DB!")
|
||||||
form = await request.form()
|
|
||||||
username = str(form.get("username"))
|
|
||||||
password = str(form.get("password"))
|
|
||||||
ui_username = os.getenv("UI_USERNAME", "admin")
|
|
||||||
ui_password = os.getenv("UI_PASSWORD", None)
|
|
||||||
if ui_password is None:
|
|
||||||
ui_password = str(master_key) if master_key is not None else None
|
|
||||||
|
|
||||||
if ui_password is None:
|
non_default_values = {k: v for k, v in data_json.items() if v is not None}
|
||||||
raise ProxyException(
|
response = await prisma_client.update_data(
|
||||||
message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
|
user_id=data_json["user_id"],
|
||||||
type="auth_error",
|
data=non_default_values,
|
||||||
param="UI_PASSWORD",
|
update_key_values=non_default_values,
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
)
|
||||||
|
return {"user_id": data_json["user_id"], **non_default_values}
|
||||||
if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
|
# update based on remaining passed in values
|
||||||
password, ui_password
|
except Exception as e:
|
||||||
):
|
traceback.print_exc()
|
||||||
user_role = "app_owner"
|
if isinstance(e, HTTPException):
|
||||||
user_id = username
|
raise ProxyException(
|
||||||
key_user_id = user_id
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||||
if (
|
type="auth_error",
|
||||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
param=getattr(e, "param", "None"),
|
||||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||||
) or user_id == "admin":
|
|
||||||
# checks if user is admin
|
|
||||||
user_role = "app_admin"
|
|
||||||
key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")
|
|
||||||
|
|
||||||
# Admin is Authe'd in - generate key for the UI to access Proxy
|
|
||||||
|
|
||||||
if os.getenv("DATABASE_URL") is not None:
|
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
|
|
||||||
)
|
)
|
||||||
else:
|
elif isinstance(e, ProxyException):
|
||||||
response = {
|
raise e
|
||||||
"token": "sk-gm",
|
|
||||||
"user_id": "litellm-dashboard",
|
|
||||||
}
|
|
||||||
|
|
||||||
key = response["token"] # type: ignore
|
|
||||||
|
|
||||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/"
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
|
|
||||||
jwt_token = jwt.encode(
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
|
||||||
"key": key,
|
|
||||||
"user_email": user_id,
|
|
||||||
"user_role": user_role,
|
|
||||||
},
|
|
||||||
"secret",
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
|
||||||
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
|
||||||
|
|
||||||
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
|
|
||||||
general_settings["allow_user_auth"] = True
|
|
||||||
return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
|
||||||
else:
|
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
|
message="Authentication Error, " + str(e),
|
||||||
type="auth_error",
|
type="auth_error",
|
||||||
param="invalid_credentials",
|
param=getattr(e, "param", "None"),
|
||||||
code=status.HTTP_401_UNAUTHORIZED,
|
code=status.HTTP_400_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#### TEAM MANAGEMENT ####
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/team/new",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewTeamResponse,
)
async def new_team(
    data: NewTeamRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new team. The team inherits the permissions of the calling key.

    Parameters:
    - team_alias: Optional[str] - User defined team alias
    - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it.
    - admins: list - A list of user IDs that will be owning the team
    - members: list - A list of user IDs that will be members of the team
    - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }

    Returns:
    - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id.

    Raises:
    - HTTPException (500) when no database client is connected.
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # Generate an id when the caller did not supply one.
    if data.team_id is None:
        data.team_id = str(uuid.uuid4())

    # Budget / limit settings are copied from the key that made the request.
    inherited_settings = {
        "max_budget": user_api_key_dict.max_budget,
        "models": user_api_key_dict.models,
        "max_parallel_requests": user_api_key_dict.max_parallel_requests,
        "tpm_limit": user_api_key_dict.tpm_limit,
        "rpm_limit": user_api_key_dict.rpm_limit,
        "budget_duration": user_api_key_dict.budget_duration,
        "budget_reset_at": user_api_key_dict.budget_reset_at,
    }
    team_record = LiteLLM_TeamTable(**data.json(), **inherited_settings)

    # Unset (None) fields are left out of the insert payload.
    insert_payload = team_record.json(exclude_none=True)
    created_row = await prisma_client.insert_data(
        data=insert_payload, table_name="team"
    )
    return created_row
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
async def update_team():
    """
    update team and members

    Stub endpoint: not implemented yet, always returns None.
    """
    return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
async def delete_team():
    """
    delete team and team keys

    Stub endpoint: not implemented yet, always returns None.
    """
    return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
async def team_info(
    # Annotated Optional because the query parameter defaults to None
    # (same declaration style as `user_info`'s `user_id`).
    team_id: Optional[str] = fastapi.Query(
        default=None, description="Team ID in the request parameters"
    )
):
    """
    get info on team + related keys

    Parameters:
    - team_id: Optional[str] - Team ID passed as a query parameter.

    Stub endpoint: not implemented yet, always returns None.
    """
    pass
|
||||||
|
|
||||||
@app.get("/sso/callback", tags=["experimental"])
|
@app.get("/sso/callback", tags=["experimental"])
|
||||||
async def auth_callback(request: Request):
|
async def auth_callback(request: Request):
|
||||||
"""Verify login"""
|
"""Verify login"""
|
||||||
|
@ -4125,117 +4098,6 @@ async def auth_callback(request: Request):
|
||||||
return RedirectResponse(url=litellm_dashboard_ui)
|
return RedirectResponse(url=litellm_dashboard_ui)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
    "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_info(
    user_id: Optional[str] = fastapi.Query(
        default=None, description="User ID in the request parameters"
    )
):
    """
    Use this to get user information. (user row + all user key info)

    Example request
    ```
    curl -X GET 'http://localhost:8000/user/info?user_id=krrish7%40berri.ai' \
    --header 'Authorization: Bearer sk-1234'
    ```

    Returns a dict with `user_id`, `user_info` (the user row, or `{"spend": <total key spend>}`
    when no row was found) and `keys` (key rows as dicts, with the hashed `token` removed).

    Raises ProxyException on any failure, including a missing database connection.
    """
    global prisma_client
    try:
        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        ## GET USER ROW ##
        if user_id is not None:
            user_row = await prisma_client.get_data(user_id=user_id)
        else:
            user_row = None
        ## GET ALL KEYS ##
        keys = await prisma_client.get_data(
            user_id=user_id,
            table_name="key",
            query_type="find_all",
            expires=datetime.now(),
        )

        if user_row is None:
            ## make sure we still return a total spend ##
            # `or 0` keeps the sum working when a key's spend is None.
            total_spend = sum((getattr(k, "spend", 0) or 0) for k in keys)
            user_row = {"spend": total_spend}

        ## REMOVE HASHED TOKEN INFO before returning ##
        # Collect the stripped dicts in a new list: rebinding the loop variable
        # alone leaves `keys` untouched and would still return the hashed tokens.
        returned_keys = []
        for key in keys:
            try:
                key_data = key.model_dump()  # noqa
            except Exception:
                # if using pydantic v1
                key_data = key.dict()
            key_data.pop("token", None)
            returned_keys.append(key_data)
        return {"user_id": user_id, "user_info": user_row, "keys": returned_keys}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
    "/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_update(data: UpdateUserRequest):
    """
    [TODO]: Use this to update user budget

    Writes every non-None field of the request to the user identified by `user_id`
    and echoes those fields back. Any failure is re-raised as a ProxyException.
    """
    global prisma_client
    try:
        payload: dict = data.json()
        # get the row from db
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        # Only fields that were actually provided (non-None) are written.
        provided_fields = {}
        for field_name, field_value in payload.items():
            if field_value is not None:
                provided_fields[field_name] = field_value

        target_user_id = payload["user_id"]
        await prisma_client.update_data(
            user_id=target_user_id,
            data=provided_fields,
            update_key_values=provided_fields,
        )
        return {"user_id": target_user_id, **provided_fields}
    except Exception as e:
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        if isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
|
||||||
|
|
||||||
|
|
||||||
#### MODEL MANAGEMENT ####
|
#### MODEL MANAGEMENT ####
|
||||||
|
|
||||||
|
|
||||||
|
@ -4587,6 +4449,341 @@ async def retrieve_server_log(request: Request):
|
||||||
return FileResponse(filepath)
|
return FileResponse(filepath)
|
||||||
|
|
||||||
|
|
||||||
|
#### LOGIN ENDPOINTS ####
|
||||||
|
|
||||||
|
|
||||||
|
def _require_sso_env(var_name: str) -> str:
    """Return the value of environment variable `var_name`, or raise a 500 ProxyException if it is unset."""
    value = os.getenv(var_name, None)
    if value is None:
        raise ProxyException(
            message=f"{var_name} not set. Set it in .env file",
            type="auth_error",
            param=var_name,
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
    return value


@app.get("/sso/key/generate", tags=["experimental"])
async def google_login(request: Request):
    """
    Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env

    PROXY_BASE_URL should be your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"

    The SSO provider is chosen from the environment in this order:
    GOOGLE_CLIENT_ID, MICROSOFT_CLIENT_ID, GENERIC_CLIENT_ID.
    If none of them is set, the username/password HTML login form is returned.

    Raises:
    - ProxyException (500) when a provider's client id is set but one of its
      other required environment variables is missing.
    """
    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    # Must be read here: the generic-provider branch below tests it, and leaving
    # it unassigned raises NameError when neither Google nor Microsoft is configured.
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    # get url from request
    redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
    if redirect_url.endswith("/"):
        redirect_url += "sso/callback"
    else:
        redirect_url += "/sso/callback"

    # Google SSO Auth
    if google_client_id is not None:
        from fastapi_sso.sso.google import GoogleSSO

        google_client_secret = _require_sso_env("GOOGLE_CLIENT_SECRET")

        google_sso = GoogleSSO(
            client_id=google_client_id,
            client_secret=google_client_secret,
            redirect_uri=redirect_url,
        )

        verbose_proxy_logger.info(
            f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
        )

        with google_sso:
            return await google_sso.get_login_redirect()

    # Microsoft SSO Auth
    elif microsoft_client_id is not None:
        from fastapi_sso.sso.microsoft import MicrosoftSSO

        microsoft_client_secret = _require_sso_env("MICROSOFT_CLIENT_SECRET")
        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)

        microsoft_sso = MicrosoftSSO(
            client_id=microsoft_client_id,
            client_secret=microsoft_client_secret,
            tenant=microsoft_tenant,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )
        with microsoft_sso:
            return await microsoft_sso.get_login_redirect()

    # Generic (OIDC-style) SSO Auth
    elif generic_client_id is not None:
        from fastapi_sso.sso.generic import create_provider, DiscoveryDocument

        # Checked in this order so the first missing variable is the one reported.
        generic_client_secret = _require_sso_env("GENERIC_CLIENT_SECRET")
        generic_authorization_endpoint = _require_sso_env(
            "GENERIC_AUTHORIZATION_ENDPOINT"
        )
        generic_token_endpoint = _require_sso_env("GENERIC_TOKEN_ENDPOINT")
        generic_userinfo_endpoint = _require_sso_env("GENERIC_USERINFO_ENDPOINT")

        verbose_proxy_logger.debug(
            f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
        )

        verbose_proxy_logger.debug(
            f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
        )

        discovery = DiscoveryDocument(
            authorization_endpoint=generic_authorization_endpoint,
            token_endpoint=generic_token_endpoint,
            userinfo_endpoint=generic_userinfo_endpoint,
        )

        SSOProvider = create_provider(name="oidc", discovery_document=discovery)
        generic_sso = SSOProvider(
            client_id=generic_client_id,
            client_secret=generic_client_secret,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )

        with generic_sso:
            return await generic_sso.get_login_redirect()

    # No SSO provider configured: serve the username/password login form.
    # This is returned whether or not UI_USERNAME is set; a branch that only
    # imports HTMLResponse would fall through and respond with None.
    from fastapi.responses import HTMLResponse

    return HTMLResponse(content=html_form, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/login", include_in_schema=False
)  # hidden since this is a helper for UI sso login
async def login(request: Request):
    """Authenticate a UI user with username/password and redirect to the dashboard.

    The expected credentials are read from the environment: ``UI_USERNAME``
    (defaults to ``"admin"``) and ``UI_PASSWORD`` (falls back to the proxy
    ``master_key`` when unset). On success a key is obtained for the UI, packed
    into a JWT together with the user id and role, and the client is redirected
    (303) to ``<PROXY_BASE_URL>/ui/?userID=...&token=...``.

    Args:
        request: Incoming request carrying ``username`` and ``password`` form
            fields.

    Returns:
        RedirectResponse: 303 redirect to the dashboard UI.

    Raises:
        ProxyException: 500 when neither ``UI_PASSWORD`` nor a master key is
            configured; 401 when the submitted credentials do not match.
    """
    try:
        import multipart
    except ImportError:
        # NOTE(review): installing a package at request time is a best-effort
        # fallback kept from the original; prefer declaring python-multipart
        # as a dependency.
        subprocess.run(["pip", "install", "python-multipart"])
    global master_key
    form = await request.form()
    username = str(form.get("username"))
    password = str(form.get("password"))
    ui_username = os.getenv("UI_USERNAME", "admin")
    ui_password = os.getenv("UI_PASSWORD", None)
    if ui_password is None:
        ui_password = str(master_key) if master_key is not None else None

    if ui_password is None:
        raise ProxyException(
            message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys",
            type="auth_error",
            param="UI_PASSWORD",
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # compare_digest() only accepts ASCII-only `str`; compare UTF-8 bytes so a
    # non-ASCII username/password yields a clean 401 instead of a TypeError.
    # Both comparisons are always evaluated (no short-circuit).
    username_ok = secrets.compare_digest(
        username.encode("utf-8"), ui_username.encode("utf-8")
    )
    password_ok = secrets.compare_digest(
        password.encode("utf-8"), ui_password.encode("utf-8")
    )
    if not (username_ok and password_ok):
        # The submitted password is intentionally NOT echoed back: error
        # messages end up in logs and client-visible responses.
        raise ProxyException(
            message=f"Invalid credentials used to access UI. Passed in username: {username}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
            type="auth_error",
            param="invalid_credentials",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    user_role = "app_owner"
    user_id = username
    key_user_id = user_id
    if (
        os.getenv("PROXY_ADMIN_ID", None) is not None
        and os.environ["PROXY_ADMIN_ID"] == user_id
    ) or user_id == "admin":
        # checks if user is admin
        user_role = "app_admin"
        key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")

    # Admin is Authe'd in - generate key for the UI to access Proxy
    if os.getenv("DATABASE_URL") is not None:
        response = await generate_key_helper_fn(
            **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"}  # type: ignore
        )
    else:
        # No database configured: hand the UI a fixed placeholder token.
        response = {
            "token": "sk-gm",
            "user_id": "litellm-dashboard",
        }

    key = response["token"]  # type: ignore

    # Make sure a configured base URL ends with "/" so the result is
    # ".../ui/" rather than "...hostui/". An empty value is left untouched.
    base_url = os.getenv("PROXY_BASE_URL", "/")
    if base_url and not base_url.endswith("/"):
        base_url += "/"
    litellm_dashboard_ui = base_url + "ui/"

    import jwt

    # NOTE(review): the JWT is signed with the hard-coded key "secret", so it
    # offers no integrity protection. Left unchanged because the consumer of
    # this token is outside this block - confirm before rotating to a real key.
    jwt_token = jwt.encode(
        {
            "user_id": user_id,
            "key": key,
            "user_email": user_id,
            "user_role": user_role,
        },
        "secret",
        algorithm="HS256",
    )
    litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token

    # if a user has logged in they should be allowed to create keys - this ensures that it's set to True
    general_settings["allow_user_auth"] = True
    return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/sso/callback", tags=["experimental"])
async def auth_callback(request: Request):
    """Verify an SSO login callback and redirect the user to the dashboard.

    The provider is chosen from the environment: Google when
    ``GOOGLE_CLIENT_ID`` is set, otherwise Microsoft when
    ``MICROSOFT_CLIENT_ID`` is set. After the provider verifies the request, a
    key is generated for the user, packed into a JWT with the user id, email
    and role, and the client is redirected to ``/ui/?userID=...&token=...``.

    Args:
        request: The callback request sent back by the SSO provider.

    Returns:
        RedirectResponse: redirect to the dashboard UI.

    Raises:
        ProxyException: 500 when a required provider secret/tenant is missing,
            or when no supported SSO provider is configured at all.
    """
    global general_settings
    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)

    # get url from request
    redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))

    if redirect_url.endswith("/"):
        redirect_url += "sso/callback"
    else:
        redirect_url += "/sso/callback"

    if google_client_id is not None:
        from fastapi_sso.sso.google import GoogleSSO

        google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
        if google_client_secret is None:
            raise ProxyException(
                message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
                type="auth_error",
                param="GOOGLE_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        google_sso = GoogleSSO(
            client_id=google_client_id,
            redirect_uri=redirect_url,
            client_secret=google_client_secret,
        )
        result = await google_sso.verify_and_process(request)

    elif microsoft_client_id is not None:
        from fastapi_sso.sso.microsoft import MicrosoftSSO

        microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
        if microsoft_client_secret is None:
            raise ProxyException(
                message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
                type="auth_error",
                param="MICROSOFT_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        if microsoft_tenant is None:
            raise ProxyException(
                message="MICROSOFT_TENANT not set. Set it in .env file",
                type="auth_error",
                param="MICROSOFT_TENANT",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        microsoft_sso = MicrosoftSSO(
            client_id=microsoft_client_id,
            client_secret=microsoft_client_secret,
            tenant=microsoft_tenant,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )
        result = await microsoft_sso.verify_and_process(request)

    else:
        # Without this branch `result` would be unbound below and the request
        # would die with an UnboundLocalError.
        # NOTE(review): the login redirect earlier in this file also supports a
        # generic OIDC provider; no matching verification branch exists here -
        # confirm whether one is needed.
        raise ProxyException(
            message="No supported SSO provider configured for callback. Set GOOGLE_CLIENT_ID or MICROSOFT_CLIENT_ID in .env file",
            type="auth_error",
            param="SSO_PROVIDER",
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # User is Authe'd in - generate key for the UI to access Proxy
    user_email = getattr(result, "email", None)
    user_id = getattr(result, "id", None)
    if user_id is None:
        # The attributes may exist but be None; `or ""` keeps the
        # concatenation from raising TypeError in that case.
        first_name = getattr(result, "first_name", "") or ""
        last_name = getattr(result, "last_name", "") or ""
        user_id = first_name + last_name

    response = await generate_key_helper_fn(
        **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email}  # type: ignore
    )
    key = response["token"]  # type: ignore
    user_id = response["user_id"]  # type: ignore

    litellm_dashboard_ui = "/ui/"

    user_role = "app_owner"
    if (
        os.getenv("PROXY_ADMIN_ID", None) is not None
        and os.environ["PROXY_ADMIN_ID"] == user_id
    ):
        # checks if user is admin
        user_role = "app_admin"

    import jwt

    # NOTE(review): signed with the hard-coded key "secret"; the token carries
    # no real integrity guarantee. Left unchanged because the consumer is
    # outside this block - confirm before changing.
    jwt_token = jwt.encode(
        {
            "user_id": user_id,
            "key": key,
            "user_email": user_email,
            "user_role": user_role,
        },
        "secret",
        algorithm="HS256",
    )
    litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token

    # if a user has logged in they should be allowed to create keys - this ensures that it's set to True
    general_settings["allow_user_auth"] = True
    return RedirectResponse(url=litellm_dashboard_ui)
|
||||||
|
|
||||||
|
|
||||||
#### BASIC ENDPOINTS ####
|
#### BASIC ENDPOINTS ####
|
||||||
@router.post(
|
@router.post(
|
||||||
"/config/update",
|
"/config/update",
|
||||||
|
|
|
@ -8,6 +8,25 @@ generator client {
|
||||||
provider = "prisma-client-py"
|
provider = "prisma-client-py"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assign prod keys to groups, not individuals
|
||||||
|
model LiteLLM_TeamTable {
  // Identity: team_id is the only uniqueness constraint declared on this model.
  team_id               String    @unique
  team_alias            String?

  // Membership, stored as string arrays.
  admins                String[]
  members               String[]

  // Free-form JSON; defaults to an empty object.
  metadata              Json      @default("{}")

  // Budget and limits: all nullable except spend, which defaults to 0.0.
  max_budget            Float?
  spend                 Float     @default(0.0)
  models                String[]
  max_parallel_requests Int?
  tpm_limit             BigInt?
  rpm_limit             BigInt?
  budget_duration       String?
  budget_reset_at       DateTime?

  // Timestamps: both default to now(); updated_at also carries @updatedAt.
  created_at            DateTime  @default(now()) @map("created_at")
  updated_at            DateTime  @default(now()) @updatedAt @map("updated_at")
}
|
||||||
|
|
||||||
// Track spend, rate limit, budget Users
|
// Track spend, rate limit, budget Users
|
||||||
model LiteLLM_UserTable {
|
model LiteLLM_UserTable {
|
||||||
user_id String @unique
|
user_id String @unique
|
||||||
|
|
|
@ -511,8 +511,9 @@ class PrismaClient:
|
||||||
token: Optional[Union[str, list]] = None,
|
token: Optional[Union[str, list]] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
user_id_list: Optional[list] = None,
|
user_id_list: Optional[list] = None,
|
||||||
|
team_id: Optional[str] = None,
|
||||||
key_val: Optional[dict] = None,
|
key_val: Optional[dict] = None,
|
||||||
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
|
||||||
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||||
expires: Optional[datetime] = None,
|
expires: Optional[datetime] = None,
|
||||||
reset_at: Optional[datetime] = None,
|
reset_at: Optional[datetime] = None,
|
||||||
|
@ -547,6 +548,14 @@ class PrismaClient:
|
||||||
for r in response:
|
for r in response:
|
||||||
if isinstance(r.expires, datetime):
|
if isinstance(r.expires, datetime):
|
||||||
r.expires = r.expires.isoformat()
|
r.expires = r.expires.isoformat()
|
||||||
|
elif query_type == "find_all" and team_id is not None:
|
||||||
|
response = await self.db.litellm_verificationtoken.find_many(
|
||||||
|
where={"team_id": team_id}
|
||||||
|
)
|
||||||
|
if response is not None and len(response) > 0:
|
||||||
|
for r in response:
|
||||||
|
if isinstance(r.expires, datetime):
|
||||||
|
r.expires = r.expires.isoformat()
|
||||||
elif (
|
elif (
|
||||||
query_type == "find_all"
|
query_type == "find_all"
|
||||||
and expires is not None
|
and expires is not None
|
||||||
|
@ -659,7 +668,23 @@ class PrismaClient:
|
||||||
order={"startTime": "desc"},
|
order={"startTime": "desc"},
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
elif table_name == "team":
|
||||||
|
if query_type == "find_unique":
|
||||||
|
response = await self.db.litellm_teamtable.find_unique(
|
||||||
|
where={"team_id": team_id} # type: ignore
|
||||||
|
)
|
||||||
|
if query_type == "find_all" and team_id is not None:
|
||||||
|
user_id_values = str(tuple(team_id))
|
||||||
|
sql_query = f"""
|
||||||
|
SELECT *
|
||||||
|
FROM "LiteLLM_TeamTable"
|
||||||
|
WHERE "team_id" IN {team_id}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Execute the raw query
|
||||||
|
# The asterisk before `team_id` unpacks the list into separate arguments
|
||||||
|
response = await self.db.query_raw(sql_query)
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -679,7 +704,7 @@ class PrismaClient:
|
||||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||||
)
|
)
|
||||||
async def insert_data(
|
async def insert_data(
|
||||||
self, data: dict, table_name: Literal["user", "key", "config", "spend"]
|
self, data: dict, table_name: Literal["user", "key", "config", "spend", "team"]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add a key to the database. If it already exists, do nothing.
|
Add a key to the database. If it already exists, do nothing.
|
||||||
|
@ -715,6 +740,17 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.info(f"Data Inserted into User Table")
|
verbose_proxy_logger.info(f"Data Inserted into User Table")
|
||||||
return new_user_row
|
return new_user_row
|
||||||
|
elif table_name == "team":
|
||||||
|
db_data = self.jsonify_object(data=data)
|
||||||
|
new_team_row = await self.db.litellm_teamtable.upsert(
|
||||||
|
where={"team_id": data["team_id"]},
|
||||||
|
data={
|
||||||
|
"create": {**db_data}, # type: ignore
|
||||||
|
"update": {}, # don't do anything if it already exists
|
||||||
|
},
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.info(f"Data Inserted into Team Table")
|
||||||
|
return new_team_row
|
||||||
elif table_name == "config":
|
elif table_name == "config":
|
||||||
"""
|
"""
|
||||||
For each param,
|
For each param,
|
||||||
|
|
|
@ -8,6 +8,25 @@ generator client {
|
||||||
provider = "prisma-client-py"
|
provider = "prisma-client-py"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assign prod keys to groups, not individuals
|
||||||
|
model LiteLLM_TeamTable {
  // Identity: team_id is the only uniqueness constraint declared on this model.
  team_id               String    @unique
  team_alias            String?

  // Membership, stored as string arrays.
  admins                String[]
  members               String[]

  // Free-form JSON; defaults to an empty object.
  metadata              Json      @default("{}")

  // Budget and limits: all nullable except spend, which defaults to 0.0.
  max_budget            Float?
  spend                 Float     @default(0.0)
  models                String[]
  max_parallel_requests Int?
  tpm_limit             BigInt?
  rpm_limit             BigInt?
  budget_duration       String?
  budget_reset_at       DateTime?

  // Timestamps: both default to now(); updated_at also carries @updatedAt.
  created_at            DateTime  @default(now()) @map("created_at")
  updated_at            DateTime  @default(now()) @updatedAt @map("updated_at")
}
|
||||||
|
|
||||||
// Track spend, rate limit, budget Users
|
// Track spend, rate limit, budget Users
|
||||||
model LiteLLM_UserTable {
|
model LiteLLM_UserTable {
|
||||||
user_id String @unique
|
user_id String @unique
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue