Merge pull request #2300 from BerriAI/litellm_organization_table

feat(proxy_server.py): enable `/organization/new`, `/organization/info` and `/budget/info` endpoints
This commit is contained in:
Krish Dholakia 2024-03-02 18:37:36 -08:00 committed by GitHub
commit ce84bce8e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 309 additions and 12 deletions

View file

@ -251,6 +251,7 @@ class Member(LiteLLMBase):
class NewTeamRequest(LiteLLMBase): class NewTeamRequest(LiteLLMBase):
team_alias: Optional[str] = None team_alias: Optional[str] = None
team_id: Optional[str] = None team_id: Optional[str] = None
organization_id: Optional[str] = None
admins: list = [] admins: list = []
members: list = [] members: list = []
members_with_roles: List[Member] = [] members_with_roles: List[Member] = []
@ -327,19 +328,46 @@ class TeamRequest(LiteLLMBase):
class LiteLLM_BudgetTable(LiteLLMBase): class LiteLLM_BudgetTable(LiteLLMBase):
"""Represents user-controllable params for a LiteLLM_BudgetTable record""" """Represents user-controllable params for a LiteLLM_BudgetTable record"""
max_budget: Optional[float] = None
soft_budget: Optional[float] = None soft_budget: Optional[float] = None
max_budget: Optional[float] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
model_max_budget: dict model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
class NewOrganizationRequest(LiteLLM_BudgetTable):
    """Request body accepted by the `/organization/new` endpoint.

    Subclasses LiteLLM_BudgetTable, so the budget fields declared on that
    class are accepted alongside the organization fields below.
    """

    # Name of the organization (the only required field).
    organization_alias: str
    # Models the organization may use. The `/organization/new` handler rejects
    # an empty list when the calling key is itself restricted to specific models.
    models: List = []
    # Id of an existing budget row. When omitted, the `/organization/new`
    # handler creates a budget row and fills this in before creating the org.
    budget_id: Optional[str] = None
class LiteLLM_OrganizationTable(LiteLLMBase):
    """Represents user-controllable params for a LiteLLM_OrganizationTable record"""

    # Optional human-readable name for the organization.
    organization_alias: Optional[str] = None
    # Required: the `/organization/new` handler makes sure a budget id is set
    # (creating a budget row if needed) before building this model.
    budget_id: str
    metadata: Optional[dict] = None
    # Model names the organization is allowed to use.
    models: List[str]
    # Audit fields. The `/organization/new` handler fills both with the
    # caller's user_id, falling back to `litellm_proxy_admin_name`.
    created_by: str
    updated_by: str
class NewOrganizationResponse(LiteLLM_OrganizationTable):
    """Response model declared for the `/organization/new` endpoint.

    Extends LiteLLM_OrganizationTable with the organization id and the
    creation / update timestamps.
    """

    organization_id: str
    created_at: datetime
    updated_at: datetime
class OrganizationRequest(LiteLLMBase):
    """Request body accepted by the `/organization/info` endpoint."""

    # Organization ids to look up (matched against `organization_id`).
    # The handler answers 400 when this list is empty.
    organizations: List[str]
class BudgetRequest(LiteLLMBase):
    """Request body accepted by the `/budget/info` endpoint."""

    # Budget ids to look up (matched against `budget_id`).
    # The handler answers 400 when this list is empty.
    budgets: List[str]
class KeyManagementSystem(enum.Enum): class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms" GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault" AZURE_KEY_VAULT = "azure_key_vault"

View file

@ -239,6 +239,7 @@ health_check_interval = None
health_check_results = {} health_check_results = {}
queue: List = [] queue: List = []
litellm_proxy_budget_name = "litellm-proxy-budget" litellm_proxy_budget_name = "litellm-proxy-budget"
litellm_proxy_admin_name = "default_user_id"
ui_access_mode: Literal["admin", "all"] = "all" ui_access_mode: Literal["admin", "all"] = "all"
proxy_budget_rescheduler_min_time = 597 proxy_budget_rescheduler_min_time = 597
proxy_budget_rescheduler_max_time = 605 proxy_budget_rescheduler_max_time = 605
@ -335,7 +336,11 @@ async def user_api_key_auth(
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key) is_master_key_valid = secrets.compare_digest(api_key, master_key)
if is_master_key_valid: if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key, user_role="proxy_admin") return UserAPIKeyAuth(
api_key=master_key,
user_role="proxy_admin",
user_id=litellm_proxy_admin_name,
)
if isinstance( if isinstance(
api_key, str api_key, str
): # if generated token, make sure it starts with sk-. ): # if generated token, make sure it starts with sk-.
@ -360,7 +365,6 @@ async def user_api_key_auth(
valid_token = await prisma_client.get_data( valid_token = await prisma_client.get_data(
token=api_key, table_name="combined_view" token=api_key, table_name="combined_view"
) )
elif custom_db_client is not None: elif custom_db_client is not None:
try: try:
valid_token = await custom_db_client.get_data( valid_token = await custom_db_client.get_data(
@ -2229,7 +2233,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
import json import json
### LOAD MASTER KEY ### ### LOAD MASTER KEY ###
@ -2276,9 +2280,8 @@ async def startup_event():
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db # add master key to db
user_id = "default_user_id"
if os.getenv("PROXY_ADMIN_ID", None) is not None: if os.getenv("PROXY_ADMIN_ID", None) is not None:
user_id = os.getenv("PROXY_ADMIN_ID") litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
asyncio.create_task( asyncio.create_task(
generate_key_helper_fn( generate_key_helper_fn(
@ -2288,7 +2291,7 @@ async def startup_event():
config={}, config={},
spend=0, spend=0,
token=master_key, token=master_key,
user_id=user_id, user_id=litellm_proxy_admin_name,
user_role="proxy_admin", user_role="proxy_admin",
query_type="update_data", query_type="update_data",
update_key_values={ update_key_values={
@ -5396,6 +5399,226 @@ async def team_info(
) )
#### ORGANIZATION MANAGEMENT ####
@router.post(
    "/organization/new",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewOrganizationResponse,
)
async def new_organization(
    data: NewOrganizationRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow orgs to own teams

    Set org level budgets + model access.

    Only admins can create orgs.

    # Parameters

    - `organization_alias`: *str* = The name of the organization.
    - `models`: *List* = The models the organization has access to.
    - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
    ### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ###
    - `max_budget`: *Optional[float]* = Max budget for org
    - `tpm_limit`: *Optional[int]* = Max tpm limit for org
    - `rpm_limit`: *Optional[int]* = Max rpm limit for org
    - `model_max_budget`: *Optional[dict]* = Max budget for a specific model
    - `budget_duration`: *Optional[str]* = Frequency of resetting org budget

    # Errors

    - `500`: no database is connected.
    - `401`: the caller's role is not `proxy_admin`.
    - `400`: the caller's key is restricted to specific models and the request
      either lists no models or lists a model outside that set.

    Case 1: Create new org **without** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "max_budget": 100
    }'
    ```

    Case 2: Create new org **with** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689"
    }'
    ```
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # A role of None also fails this comparison, so no separate None check is needed.
    if user_api_key_dict.user_role != "proxy_admin":
        raise HTTPException(
            status_code=401,
            detail={
                "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}"
            },
        )

    # Ensure only models that the user has access to are given to the org.
    # This runs BEFORE any database write so that a rejected request does not
    # leave an orphaned budget row behind.
    allowed_models = user_api_key_dict.models
    if allowed_models:  # an empty list means the user has access to all models
        if not data.models:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "User not allowed to give access to all models. Select models you want org to have access to."
                },
            )
        for m in data.models:
            if m not in allowed_models:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}"
                    },
                )

    # Recorded as both creator and last updater on the rows written below.
    actor_id = user_api_key_dict.user_id or litellm_proxy_admin_name

    if data.budget_id is None:
        # Every organization needs a budget attached.
        # If none is provided, create one based on the provided values.
        # NOTE(review): `.json(...)` is unpacked with `**`, so it presumably
        # returns a dict here (defined on the base model, outside this view).
        budget_row = LiteLLM_BudgetTable(**data.json(exclude_none=True))

        new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))

        _budget = await prisma_client.db.litellm_budgettable.create(
            data={
                **new_budget,  # type: ignore
                "created_by": actor_id,
                "updated_by": actor_id,
            }
        )  # type: ignore

        data.budget_id = _budget.budget_id

    organization_row = LiteLLM_OrganizationTable(
        **data.json(exclude_none=True),
        created_by=actor_id,
        updated_by=actor_id,
    )
    new_organization_row = prisma_client.jsonify_object(
        organization_row.json(exclude_none=True)
    )
    response = await prisma_client.db.litellm_organizationtable.create(
        data={
            **new_organization_row,  # type: ignore
        }
    )

    return response
@router.post(
    "/organization/update",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_organization():
    """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
    # Placeholder route: it takes no input and performs no work.
    return None
@router.post(
    "/organization/delete",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_organization():
    """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
    # Placeholder route: it takes no input and performs no work.
    return None
@router.post(
    "/organization/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_organization(data: OrganizationRequest):
    """
    Get the org specific information

    Looks up every organization whose `organization_id` is in
    `data.organizations` and returns the rows together with their related
    `litellm_budget_table` record.

    Responds with 500 when no database is connected and with 400 when the
    list of organization ids is empty.
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if not data.organizations:
        error_detail = {
            "error": f"Specify list of organization id's to query. Passed in={data.organizations}"
        }
        raise HTTPException(status_code=400, detail=error_detail)

    id_filter = {"organization_id": {"in": data.organizations}}
    return await prisma_client.db.litellm_organizationtable.find_many(
        where=id_filter,
        include={"litellm_budget_table": True},
    )
#### BUDGET TABLE MANAGEMENT ####
@router.post(
    "/budget/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
    """
    Get the budget id specific information

    Returns every budget row whose `budget_id` is in `data.budgets`.

    Responds with 500 when no database is connected and with 400 when the
    list of budget ids is empty.
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if not data.budgets:
        error_detail = {
            "error": f"Specify list of budget id's to query. Passed in={data.budgets}"
        }
        raise HTTPException(status_code=400, detail=error_detail)

    id_filter = {"budget_id": {"in": data.budgets}}
    return await prisma_client.db.litellm_budgettable.find_many(where=id_filter)
#### MODEL MANAGEMENT #### #### MODEL MANAGEMENT ####

View file

@ -15,7 +15,7 @@ model LiteLLM_BudgetTable {
max_parallel_requests Int? max_parallel_requests Int?
tpm_limit BigInt? tpm_limit BigInt?
rpm_limit BigInt? rpm_limit BigInt?
model_max_budget Json @default("{}") model_max_budget Json?
budget_duration String? budget_duration String?
budget_reset_at DateTime? budget_reset_at DateTime?
created_at DateTime @default(now()) @map("created_at") created_at DateTime @default(now()) @map("created_at")

View file

@ -15,7 +15,7 @@ model LiteLLM_BudgetTable {
max_parallel_requests Int? max_parallel_requests Int?
tpm_limit BigInt? tpm_limit BigInt?
rpm_limit BigInt? rpm_limit BigInt?
model_max_budget Json @default("{}") model_max_budget Json?
budget_duration String? budget_duration String?
budget_reset_at DateTime? budget_reset_at DateTime?
created_at DateTime @default(now()) @map("created_at") created_at DateTime @default(now()) @map("created_at")

View file

@ -0,0 +1,46 @@
# What this tests ?
## Tests /organization endpoints.
import pytest
import asyncio
import aiohttp
import time, uuid
from openai import AsyncOpenAI
async def new_organization(session, i, organization_alias, max_budget=None):
    """
    POST a new organization to the local proxy and return the parsed JSON body.

    - `session`: object whose `post(...)` is used as an async context manager
    - `i`: request index, used only in the printed / raised messages
    - `organization_alias`: alias sent in the request body
    - `max_budget`: sent as-is in the request body (None when not given)

    Raises an Exception when the response status is not 200.
    """
    payload = {
        "organization_alias": organization_alias,
        "models": ["azure-models"],
        "max_budget": max_budget,
    }

    async with session.post(
        "http://0.0.0.0:4000/organization/new",
        headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
        json=payload,
    ) as response:
        status = response.status
        body_text = await response.text()

        # Echo the raw response so a failing run shows what the server said.
        print(f"Response {i} (Status code: {status}):")
        print(body_text)
        print()

        if status == 200:
            return await response.json()

        raise Exception(f"Request {i} did not return a 200 status code: {status}")
@pytest.mark.asyncio
async def test_organization_new():
    """
    Make 19 parallel calls to /organization/new. Assert all worked.

    Every call uses the same randomly generated organization alias. Any
    exception raised by a call is propagated by `asyncio.gather`, which
    fails the test.
    """
    organization_alias = f"Organization: {uuid.uuid4()}"
    async with aiohttp.ClientSession() as session:
        tasks = [
            new_organization(
                # Pass the loop index so each request is identifiable in the
                # printed output and in any raised error message.
                session=session,
                i=i,
                organization_alias=organization_alias,
            )
            for i in range(1, 20)
        ]
        await asyncio.gather(*tasks)