forked from phoenix/litellm-mirror
Merge pull request #2300 from BerriAI/litellm_organization_table
feat(proxy_server.py): enable `/organizations/new`, `/organization/info` and `/budget/info` endpoints
This commit is contained in:
commit
ce84bce8e6
5 changed files with 309 additions and 12 deletions
|
@ -251,6 +251,7 @@ class Member(LiteLLMBase):
|
|||
class NewTeamRequest(LiteLLMBase):
|
||||
team_alias: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
admins: list = []
|
||||
members: list = []
|
||||
members_with_roles: List[Member] = []
|
||||
|
@ -327,19 +328,46 @@ class TeamRequest(LiteLLMBase):
|
|||
|
||||
class LiteLLM_BudgetTable(LiteLLMBase):
    """Represents user-controllable params for a LiteLLM_BudgetTable record"""

    # Every field is declared exactly once. `max_budget` and
    # `model_max_budget` used to appear twice; for `model_max_budget` the two
    # declarations disagreed (required `dict` vs. optional), and only the
    # later, optional one took effect. That effective form is kept here.
    max_budget: Optional[float] = None
    soft_budget: Optional[float] = None
    max_parallel_requests: Optional[int] = None
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    model_max_budget: Optional[dict] = None
    budget_duration: Optional[str] = None
    budget_reset_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class NewOrganizationRequest(LiteLLM_BudgetTable):
    """Request payload for creating an organization.

    Inherits the budget fields of `LiteLLM_BudgetTable` and adds the
    organization-specific fields below.
    """

    # The only field without a default, so callers must supply it.
    organization_alias: str
    # Defaults to an empty list.
    models: List = []
    # Optional id of a budget record; defaults to None.
    budget_id: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_OrganizationTable(LiteLLMBase):
    """Represents user-controllable params for a LiteLLM_OrganizationTable record"""

    # Optional fields default to None; the remaining fields are required.
    organization_alias: Optional[str] = None
    budget_id: str
    metadata: Optional[dict] = None
    models: List[str]
    # Identifiers recorded for who created / last updated the record.
    created_by: str
    updated_by: str
|
||||
|
||||
|
||||
class NewOrganizationResponse(LiteLLM_OrganizationTable):
    """Organization record plus its id and creation/update timestamps.

    Adds three required fields on top of `LiteLLM_OrganizationTable`.
    """

    organization_id: str
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
class OrganizationRequest(LiteLLMBase):
    """Request payload carrying a list of organization ids."""

    # Required list of ids (strings).
    organizations: List[str]
|
||||
|
||||
|
||||
class BudgetRequest(LiteLLMBase):
    """Request payload carrying a list of budget ids."""

    # Required list of ids (strings).
    budgets: List[str]
|
||||
|
||||
|
||||
class KeyManagementSystem(enum.Enum):
|
||||
GOOGLE_KMS = "google_kms"
|
||||
AZURE_KEY_VAULT = "azure_key_vault"
|
||||
|
|
|
@ -239,6 +239,7 @@ health_check_interval = None
|
|||
health_check_results = {}
|
||||
queue: List = []
|
||||
litellm_proxy_budget_name = "litellm-proxy-budget"
|
||||
litellm_proxy_admin_name = "default_user_id"
|
||||
ui_access_mode: Literal["admin", "all"] = "all"
|
||||
proxy_budget_rescheduler_min_time = 597
|
||||
proxy_budget_rescheduler_max_time = 605
|
||||
|
@ -335,7 +336,11 @@ async def user_api_key_auth(
|
|||
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key)
|
||||
if is_master_key_valid:
|
||||
return UserAPIKeyAuth(api_key=master_key, user_role="proxy_admin")
|
||||
return UserAPIKeyAuth(
|
||||
api_key=master_key,
|
||||
user_role="proxy_admin",
|
||||
user_id=litellm_proxy_admin_name,
|
||||
)
|
||||
if isinstance(
|
||||
api_key, str
|
||||
): # if generated token, make sure it starts with sk-.
|
||||
|
@ -360,7 +365,6 @@ async def user_api_key_auth(
|
|||
valid_token = await prisma_client.get_data(
|
||||
token=api_key, table_name="combined_view"
|
||||
)
|
||||
|
||||
elif custom_db_client is not None:
|
||||
try:
|
||||
valid_token = await custom_db_client.get_data(
|
||||
|
@ -2229,7 +2233,7 @@ def parse_cache_control(cache_control):
|
|||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
|
||||
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
|
||||
import json
|
||||
|
||||
### LOAD MASTER KEY ###
|
||||
|
@ -2276,9 +2280,8 @@ async def startup_event():
|
|||
|
||||
if prisma_client is not None and master_key is not None:
|
||||
# add master key to db
|
||||
user_id = "default_user_id"
|
||||
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
||||
user_id = os.getenv("PROXY_ADMIN_ID")
|
||||
litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
|
||||
|
||||
asyncio.create_task(
|
||||
generate_key_helper_fn(
|
||||
|
@ -2288,7 +2291,7 @@ async def startup_event():
|
|||
config={},
|
||||
spend=0,
|
||||
token=master_key,
|
||||
user_id=user_id,
|
||||
user_id=litellm_proxy_admin_name,
|
||||
user_role="proxy_admin",
|
||||
query_type="update_data",
|
||||
update_key_values={
|
||||
|
@ -5396,6 +5399,226 @@ async def team_info(
|
|||
)
|
||||
|
||||
|
||||
#### ORGANIZATION MANAGEMENT ####
|
||||
|
||||
|
||||
@router.post(
    "/organization/new",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewOrganizationResponse,
)
async def new_organization(
    data: NewOrganizationRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow orgs to own teams

    Set org level budgets + model access.

    Only admins can create orgs.

    # Parameters

    - `organization_alias`: *str* = The name of the organization.
    - `models`: *List* = The models the organization has access to.
    - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
    ### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ###
    - `max_budget`: *Optional[float]* = Max budget for org
    - `tpm_limit`: *Optional[int]* = Max tpm limit for org
    - `rpm_limit`: *Optional[int]* = Max rpm limit for org
    - `model_max_budget`: *Optional[dict]* = Max budget for a specific model
    - `budget_duration`: *Optional[str]* = Frequency of resetting org budget

    Case 1: Create new org **without** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \

    --header 'Authorization: Bearer sk-1234' \

    --header 'Content-Type: application/json' \

    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "max_budget": 100
    }'


    ```

    Case 2: Create new org **with** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \

    --header 'Authorization: Bearer sk-1234' \

    --header 'Content-Type: application/json' \

    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689"
    }'
    ```
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # A missing role (None) also fails this comparison, so no separate
    # `is None` check is needed.
    if user_api_key_dict.user_role != "proxy_admin":
        raise HTTPException(
            status_code=401,
            detail={
                "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}"
            },
        )

    # Ensure only models that the user has access to are given to the org.
    # This runs BEFORE any database write so that a rejected request does not
    # leave an orphaned budget row behind.
    # An empty `models` list on the key means the user has access to all models.
    if len(user_api_key_dict.models) > 0:
        if len(data.models) == 0:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"User not allowed to give access to all models. Select models you want org to have access to."
                },
            )
        for m in data.models:
            if m not in user_api_key_dict.models:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}"
                    },
                )

    # Who to record as creator/updater; falls back to the proxy admin name
    # when the key carries no user id.
    acting_user = user_api_key_dict.user_id or litellm_proxy_admin_name

    if data.budget_id is None:
        # Every organization needs a budget attached.
        # If none provided, create one based on provided values.
        budget_row = LiteLLM_BudgetTable(**data.json(exclude_none=True))

        new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))

        _budget = await prisma_client.db.litellm_budgettable.create(
            data={
                **new_budget,  # type: ignore
                "created_by": acting_user,
                "updated_by": acting_user,
            }
        )  # type: ignore

        data.budget_id = _budget.budget_id

    organization_row = LiteLLM_OrganizationTable(
        **data.json(exclude_none=True),
        created_by=acting_user,
        updated_by=acting_user,
    )
    new_organization_row = prisma_client.jsonify_object(
        organization_row.json(exclude_none=True)
    )
    response = await prisma_client.db.litellm_organizationtable.create(
        data={
            **new_organization_row,  # type: ignore
        }
    )

    return response
|
||||
|
||||
|
||||
@router.post(
    "/organization/update",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_organization():
    """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
    # Placeholder endpoint: does no work and returns None.
    return None
|
||||
|
||||
|
||||
@router.post(
    "/organization/delete",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_organization():
    """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
    # Placeholder endpoint: does no work and returns None.
    return None
|
||||
|
||||
|
||||
@router.post(
    "/organization/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_organization(data: OrganizationRequest):
    """
    Get the org specific information
    """
    global prisma_client

    # Guard: the lookup below needs a database client.
    if prisma_client is None:
        no_db_detail = {"error": "No db connected"}
        raise HTTPException(status_code=500, detail=no_db_detail)

    # Guard: an empty id list is rejected instead of being queried.
    if len(data.organizations) == 0:
        empty_detail = {
            "error": f"Specify list of organization id's to query. Passed in={data.organizations}"
        }
        raise HTTPException(status_code=400, detail=empty_detail)

    # Fetch every organization whose id is in the requested list, with the
    # `litellm_budget_table` relation included in each result.
    id_filter = {"organization_id": {"in": data.organizations}}
    response = await prisma_client.db.litellm_organizationtable.find_many(
        where=id_filter,
        include={"litellm_budget_table": True},
    )
    return response
|
||||
|
||||
|
||||
#### BUDGET TABLE MANAGEMENT ####
|
||||
|
||||
|
||||
@router.post(
    "/budget/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
    """
    Get the budget id specific information
    """
    global prisma_client

    # Guard: the lookup below needs a database client.
    if prisma_client is None:
        no_db_detail = {"error": "No db connected"}
        raise HTTPException(status_code=500, detail=no_db_detail)

    # Guard: an empty id list is rejected instead of being queried.
    if len(data.budgets) == 0:
        empty_detail = {
            "error": f"Specify list of budget id's to query. Passed in={data.budgets}"
        }
        raise HTTPException(status_code=400, detail=empty_detail)

    # Fetch every budget record whose id is in the requested list.
    id_filter = {"budget_id": {"in": data.budgets}}
    response = await prisma_client.db.litellm_budgettable.find_many(
        where=id_filter,
    )
    return response
|
||||
|
||||
|
||||
#### MODEL MANAGEMENT ####
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ model LiteLLM_BudgetTable {
|
|||
max_parallel_requests Int?
|
||||
tpm_limit BigInt?
|
||||
rpm_limit BigInt?
|
||||
model_max_budget Json @default("{}")
|
||||
model_max_budget Json?
|
||||
budget_duration String?
|
||||
budget_reset_at DateTime?
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
|
|
|
@ -15,7 +15,7 @@ model LiteLLM_BudgetTable {
|
|||
max_parallel_requests Int?
|
||||
tpm_limit BigInt?
|
||||
rpm_limit BigInt?
|
||||
model_max_budget Json @default("{}")
|
||||
model_max_budget Json?
|
||||
budget_duration String?
|
||||
budget_reset_at DateTime?
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
|
|
46
tests/test_organizations.py
Normal file
46
tests/test_organizations.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
# What this tests ?
|
||||
## Tests /organization endpoints.
|
||||
import pytest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time, uuid
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def new_organization(session, i, organization_alias, max_budget=None):
    """POST to /organization/new and return the decoded JSON response.

    `session` is used for the request, `i` only labels the printed output,
    and `organization_alias` / `max_budget` are sent in the request body.
    Raises `Exception` when the response status is not 200.
    """
    endpoint = "http://0.0.0.0:4000/organization/new"
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    payload = {
        "organization_alias": organization_alias,
        "models": ["azure-models"],
        "max_budget": max_budget,
    }

    async with session.post(endpoint, headers=request_headers, json=payload) as response:
        status = response.status
        response_text = await response.text()

        # Always print the raw response so failures are visible in test logs.
        print(f"Response {i} (Status code: {status}):")
        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request {i} did not return a 200 status code: {status}")

        parsed = await response.json()
        return parsed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_organization_new():
    """
    Make 19 parallel calls to /organization/new. Assert all worked.

    Every call uses the same randomly generated organization alias. A call
    that does not return status 200 raises inside `new_organization`, which
    `asyncio.gather` propagates and thereby fails the test.
    """
    organization_alias = f"Organization: {uuid.uuid4()}"
    async with aiohttp.ClientSession() as session:
        tasks = [
            new_organization(
                # Pass the loop index so each request's log output is
                # labelled distinctly (it used to be hard-coded to 0).
                session=session, i=i, organization_alias=organization_alias
            )
            for i in range(1, 20)
        ]
        await asyncio.gather(*tasks)
|
Loading…
Add table
Add a link
Reference in a new issue