feat(proxy_server.py): enable /organization/new endpoint

allows admins to create organizations which can own teams
This commit is contained in:
Krrish Dholakia 2024-03-02 11:55:16 -08:00
parent 468995b288
commit 2602102ce6
4 changed files with 202 additions and 17 deletions

View file

@ -324,6 +324,44 @@ class TeamRequest(LiteLLMBase):
teams: List[str] teams: List[str]
class NewOrganizationRequest(LiteLLMBase):
    """Request body accepted by the `/organization/new` endpoint."""

    # Optional human-readable name for the organization.
    organization_alias: Optional[str] = None
    # Defaults to an empty list when the caller sends nothing.
    models: List = []
    # Existing budget to attach; when left unset, `new_organization`
    # creates a budget row and fills this field in before the insert.
    budget_id: Optional[str] = None
    # Optional limits supplied by the caller.
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    max_budget: Optional[float] = None
class LiteLLM_BudgetTable(LiteLLMBase):
    """Represents user-controllable params for a LiteLLM_BudgetTable record"""

    # --- spend / rate limits (all optional) ---
    max_budget: Optional[float] = None
    max_parallel_requests: Optional[int] = None
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None

    # Required here: callers must pass a dict explicitly.
    model_max_budget: dict

    # --- budget reset schedule (optional) ---
    budget_duration: Optional[str] = None
    budget_reset_at: Optional[datetime] = None

    # --- audit fields (required) ---
    created_by: str
    updated_by: str
class LiteLLM_OrganizationTable(LiteLLMBase):
    """Organization record; used as the response model of `/organization/new`."""

    # --- identity ---
    organization_id: str
    organization_alias: Optional[str] = None

    # Budget row this organization is attached to.
    budget_id: str

    # --- configuration ---
    metadata: dict
    models: List[str]

    # --- spend tracking ---
    spend: float
    model_spend: dict

    # --- audit fields ---
    created_at: datetime
    created_by: str
    updated_at: datetime
    updated_by: str
class KeyManagementSystem(enum.Enum): class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms" GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault" AZURE_KEY_VAULT = "azure_key_vault"

View file

@ -239,6 +239,7 @@ health_check_interval = None
health_check_results = {} health_check_results = {}
queue: List = [] queue: List = []
litellm_proxy_budget_name = "litellm-proxy-budget" litellm_proxy_budget_name = "litellm-proxy-budget"
litellm_proxy_admin_name = "default_user_id"
ui_access_mode: Literal["admin", "all"] = "all" ui_access_mode: Literal["admin", "all"] = "all"
proxy_budget_rescheduler_min_time = 597 proxy_budget_rescheduler_min_time = 597
proxy_budget_rescheduler_max_time = 605 proxy_budget_rescheduler_max_time = 605
@ -335,7 +336,11 @@ async def user_api_key_auth(
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key) is_master_key_valid = secrets.compare_digest(api_key, master_key)
if is_master_key_valid: if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key, user_role="proxy_admin") return UserAPIKeyAuth(
api_key=master_key,
user_role="proxy_admin",
user_id=litellm_proxy_admin_name,
)
if isinstance( if isinstance(
api_key, str api_key, str
): # if generated token, make sure it starts with sk-. ): # if generated token, make sure it starts with sk-.
@ -360,7 +365,6 @@ async def user_api_key_auth(
valid_token = await prisma_client.get_data( valid_token = await prisma_client.get_data(
token=api_key, table_name="combined_view" token=api_key, table_name="combined_view"
) )
elif custom_db_client is not None: elif custom_db_client is not None:
try: try:
valid_token = await custom_db_client.get_data( valid_token = await custom_db_client.get_data(
@ -2213,7 +2217,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
import json import json
### LOAD MASTER KEY ### ### LOAD MASTER KEY ###
@ -2260,9 +2264,8 @@ async def startup_event():
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db # add master key to db
user_id = "default_user_id"
if os.getenv("PROXY_ADMIN_ID", None) is not None: if os.getenv("PROXY_ADMIN_ID", None) is not None:
user_id = os.getenv("PROXY_ADMIN_ID") litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
asyncio.create_task( asyncio.create_task(
generate_key_helper_fn( generate_key_helper_fn(
@ -2272,7 +2275,7 @@ async def startup_event():
config={}, config={},
spend=0, spend=0,
token=master_key, token=master_key,
user_id=user_id, user_id=litellm_proxy_admin_name,
user_role="proxy_admin", user_role="proxy_admin",
query_type="update_data", query_type="update_data",
update_key_values={ update_key_values={
@ -5412,6 +5415,81 @@ async def team_info(
) )
#### ORGANIZATION MANAGEMENT ####
@router.post(
    "/organization/new",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_OrganizationTable,
)
async def new_organization(
    data: NewOrganizationRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create an organization and make sure it has a budget attached.

    Parameters:
    - data: NewOrganizationRequest - organization fields. If `budget_id` is
      not set, a new budget row is created first and its id is used.
    - user_api_key_dict: UserAPIKeyAuth - the authenticated caller. Its limits
      are the defaults for a newly created budget, and its `user_id` is
      recorded as `created_by` / `updated_by`.

    Returns:
    - The row returned by `litellm_organizationtable.create`.

    Raises:
    - HTTPException(500) when no database client is connected.
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # `created_by` / `updated_by` are non-nullable columns on both tables, so
    # fall back to the proxy admin name when the caller's key has no user_id.
    actor_id = user_api_key_dict.user_id or litellm_proxy_admin_name

    # These request fields describe the budget. The organization table has no
    # matching columns, so they must not be forwarded to the organization insert.
    budget_only_fields = {"max_budget", "tpm_limit", "rpm_limit"}

    if data.budget_id is None:
        # Every organization needs a budget attached.
        # If none is provided, create one: values given in the request win,
        # anything left unset falls back to the caller's own limits.
        budget_row = LiteLLM_BudgetTable(
            max_budget=(
                data.max_budget
                if data.max_budget is not None
                else user_api_key_dict.max_budget
            ),
            max_parallel_requests=user_api_key_dict.max_parallel_requests,
            # model_max_budget is a required dict on LiteLLM_BudgetTable;
            # guard against the key carrying no value for it.
            model_max_budget=user_api_key_dict.model_max_budget or {},
            tpm_limit=(
                data.tpm_limit
                if data.tpm_limit is not None
                else user_api_key_dict.tpm_limit
            ),
            rpm_limit=(
                data.rpm_limit
                if data.rpm_limit is not None
                else user_api_key_dict.rpm_limit
            ),
            budget_duration=user_api_key_dict.budget_duration,
            budget_reset_at=user_api_key_dict.budget_reset_at,
            created_by=actor_id,
            updated_by=actor_id,
        )

        new_budget = prisma_client.jsonify_object(
            budget_row.model_dump(exclude_none=True)
        )

        _budget = await prisma_client.db.litellm_budgettable.create(data={**new_budget})  # type: ignore

        data.budget_id = _budget.budget_id

    organization_fields = data.model_dump(
        exclude_none=True, exclude=budget_only_fields
    )

    response = await prisma_client.db.litellm_organizationtable.create(
        data={
            **organization_fields,  # type: ignore
            "created_by": actor_id,
            "updated_by": actor_id,
        }
    )

    return response
@router.post(
    "/organization/update",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_OrganizationTable,
)
async def update_organization():
    """
    Placeholder for updating an organization (not implemented yet).

    Raises:
    - HTTPException(501) always. An empty body would return None, which
      cannot satisfy the declared response model.
    """
    raise HTTPException(
        status_code=501,
        detail={"error": "/organization/update is not implemented yet"},
    )
@router.post(
    "/organization/delete",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_OrganizationTable,
)
async def delete_organization():
    """
    Placeholder for deleting an organization (not implemented yet).

    Raises:
    - HTTPException(501) always. An empty body would return None, which
      cannot satisfy the declared response model.
    """
    raise HTTPException(
        status_code=501,
        detail={"error": "/organization/delete is not implemented yet"},
    )
#### MODEL MANAGEMENT #### #### MODEL MANAGEMENT ####

View file

@ -7,10 +7,44 @@ generator client {
provider = "prisma-client-py" provider = "prisma-client-py"
} }
// Budget / Rate Limits for an org
model LiteLLM_BudgetTable {
  // Primary key, generated as a UUID when not supplied
  budget_id             String    @id @default(uuid())

  // Optional spend / rate limits
  max_budget            Float?
  max_parallel_requests Int?
  tpm_limit             BigInt?
  rpm_limit             BigInt?
  model_max_budget      Json      @default("{}")

  // Optional reset schedule
  budget_duration       String?
  budget_reset_at       DateTime?

  // Audit columns
  created_at            DateTime  @default(now()) @map("created_at")
  created_by            String
  updated_at            DateTime  @default(now()) @updatedAt @map("updated_at")
  updated_by            String

  // Back-relation: organizations whose budget_id references this row
  organization          LiteLLM_OrganizationTable[]
}
model LiteLLM_OrganizationTable {
  // Primary key, generated as a UUID when not supplied
  organization_id      String   @id @default(uuid())
  organization_alias   String?

  // Foreign key into LiteLLM_BudgetTable
  budget_id            String

  metadata             Json     @default("{}")
  models               String[]

  // Spend tracking
  spend                Float    @default(0.0)
  model_spend          Json     @default("{}")

  // Audit columns
  created_at           DateTime @default(now()) @map("created_at")
  created_by           String
  updated_at           DateTime @default(now()) @updatedAt @map("updated_at")
  updated_by           String

  // Relations
  litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
  teams                LiteLLM_TeamTable[]
}
// Assign prod keys to groups, not individuals // Assign prod keys to groups, not individuals
model LiteLLM_TeamTable { model LiteLLM_TeamTable {
team_id String @unique team_id String @id @default(uuid())
team_alias String? team_alias String?
organization_id String?
admins String[] admins String[]
members String[] members String[]
members_with_roles Json @default("{}") members_with_roles Json @default("{}")
@ -27,11 +61,12 @@ model LiteLLM_TeamTable {
updated_at DateTime @default(now()) @updatedAt @map("updated_at") updated_at DateTime @default(now()) @updatedAt @map("updated_at")
model_spend Json @default("{}") model_spend Json @default("{}")
model_max_budget Json @default("{}") model_max_budget Json @default("{}")
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
} }
// Track spend, rate limit, budget Users // Track spend, rate limit, budget Users
model LiteLLM_UserTable { model LiteLLM_UserTable {
user_id String @unique user_id String @id
team_id String? team_id String?
teams String[] @default([]) teams String[] @default([])
user_role String? user_role String?
@ -51,7 +86,7 @@ model LiteLLM_UserTable {
// Generate Tokens for Proxy // Generate Tokens for Proxy
model LiteLLM_VerificationToken { model LiteLLM_VerificationToken {
token String @unique token String @id
key_name String? key_name String?
key_alias String? key_alias String?
spend Float @default(0.0) spend Float @default(0.0)
@ -82,7 +117,7 @@ model LiteLLM_Config {
// View spend, model, api_key per request // View spend, model, api_key per request
model LiteLLM_SpendLogs { model LiteLLM_SpendLogs {
request_id String @unique request_id String @id
call_type String call_type String
api_key String @default ("") api_key String @default ("")
spend Float @default(0.0) spend Float @default(0.0)
@ -100,9 +135,10 @@ model LiteLLM_SpendLogs {
team_id String? team_id String?
end_user String? end_user String?
} }
// Beta - allow team members to request access to a model // Beta - allow team members to request access to a model
model LiteLLM_UserNotifications { model LiteLLM_UserNotifications {
request_id String @unique request_id String @id
user_id String user_id String
models String[] models String[]
justification String justification String

View file

@ -7,10 +7,42 @@ generator client {
provider = "prisma-client-py" provider = "prisma-client-py"
} }
// Budget / Rate Limits for an org
model LiteLLM_BudgetTable {
  // Primary key, generated as a UUID when not supplied
  budget_id             String    @id @default(uuid())

  // Optional spend / rate limits
  max_budget            Float?
  max_parallel_requests Int?
  tpm_limit             BigInt?
  rpm_limit             BigInt?
  model_max_budget      Json      @default("{}")

  // Optional reset schedule
  budget_duration       String?
  budget_reset_at       DateTime?

  // Audit columns
  created_at            DateTime  @default(now()) @map("created_at")
  created_by            String
  updated_at            DateTime  @default(now()) @updatedAt @map("updated_at")
  updated_by            String

  // Back-relation for LiteLLM_OrganizationTable.litellm_budget_table.
  // Prisma requires both sides of a relation to be declared.
  organization          LiteLLM_OrganizationTable[]
}
model LiteLLM_OrganizationTable {
  // Primary key, generated as a UUID when not supplied
  organization_id      String   @id @default(uuid())
  organization_alias   String?

  // Foreign key into LiteLLM_BudgetTable
  budget_id            String

  metadata             Json     @default("{}")
  models               String[]

  // Spend tracking
  spend                Float    @default(0.0)
  model_spend          Json     @default("{}")

  // Audit columns
  created_at           DateTime @default(now()) @map("created_at")
  created_by           String
  updated_at           DateTime @default(now()) @updatedAt @map("updated_at")
  updated_by           String

  // Relations
  litellm_budget_table LiteLLM_BudgetTable @relation(fields: [budget_id], references: [budget_id])
  // Back-relation for LiteLLM_TeamTable.litellm_organization_table.
  // Prisma requires both sides of a relation to be declared.
  teams                LiteLLM_TeamTable[]
}
// Assign prod keys to groups, not individuals // Assign prod keys to groups, not individuals
model LiteLLM_TeamTable { model LiteLLM_TeamTable {
team_id String @unique team_id String @id @default(uuid())
team_alias String? team_alias String?
organization_id String?
admins String[] admins String[]
members String[] members String[]
members_with_roles Json @default("{}") members_with_roles Json @default("{}")
@ -27,11 +59,12 @@ model LiteLLM_TeamTable {
updated_at DateTime @default(now()) @updatedAt @map("updated_at") updated_at DateTime @default(now()) @updatedAt @map("updated_at")
model_spend Json @default("{}") model_spend Json @default("{}")
model_max_budget Json @default("{}") model_max_budget Json @default("{}")
litellm_organization_table LiteLLM_OrganizationTable @relation(fields: [organization_id], references: [organization_id])
} }
// Track spend, rate limit, budget Users // Track spend, rate limit, budget Users
model LiteLLM_UserTable { model LiteLLM_UserTable {
user_id String @unique user_id String @id
team_id String? team_id String?
teams String[] @default([]) teams String[] @default([])
user_role String? user_role String?
@ -51,7 +84,7 @@ model LiteLLM_UserTable {
// Generate Tokens for Proxy // Generate Tokens for Proxy
model LiteLLM_VerificationToken { model LiteLLM_VerificationToken {
token String @unique token String @id
key_name String? key_name String?
key_alias String? key_alias String?
spend Float @default(0.0) spend Float @default(0.0)
@ -82,7 +115,7 @@ model LiteLLM_Config {
// View spend, model, api_key per request // View spend, model, api_key per request
model LiteLLM_SpendLogs { model LiteLLM_SpendLogs {
request_id String @unique request_id String @id
call_type String call_type String
api_key String @default ("") api_key String @default ("")
spend Float @default(0.0) spend Float @default(0.0)
@ -103,7 +136,7 @@ model LiteLLM_SpendLogs {
// Beta - allow team members to request access to a model // Beta - allow team members to request access to a model
model LiteLLM_UserNotifications { model LiteLLM_UserNotifications {
request_id String @unique request_id String @id
user_id String user_id String
models String[] models String[]
justification String justification String