fix(proxy_server.py): enable admin to create new budget if none set for org

This commit is contained in:
Krrish Dholakia 2024-03-02 14:38:42 -08:00
parent 6fb19c5d42
commit 8d22ed762e
3 changed files with 134 additions and 39 deletions

View file

@ -324,15 +324,6 @@ class TeamRequest(LiteLLMBase):
teams: List[str]
class NewOrganizationRequest(LiteLLMBase):
organization_alias: Optional[str] = None
models: List = []
budget_id: Optional[str] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
max_budget: Optional[float] = None
class LiteLLM_BudgetTable(LiteLLMBase):
"""Represents user-controllable params for a LiteLLM_BudgetTable record"""
@ -340,28 +331,33 @@ class LiteLLM_BudgetTable(LiteLLMBase):
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
model_max_budget: dict
model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
created_by: str
updated_by: str
class NewOrganizationRequest(LiteLLM_BudgetTable):
organization_alias: str
models: List = []
budget_id: Optional[str] = None
class LiteLLM_OrganizationTable(LiteLLMBase):
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
organization_id: str
organization_alias: Optional[str] = None
budget_id: str
metadata: dict
metadata: Optional[dict] = None
models: List[str]
spend: float
model_spend: dict
created_at: datetime
created_by: str
updated_at: datetime
updated_by: str
class NewOrganizationResponse(LiteLLM_OrganizationTable):
organization_id: str
created_at: datetime
updated_at: datetime
class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"

View file

@ -5422,46 +5422,133 @@ async def team_info(
"/organization/new",
tags=["organization management"],
dependencies=[Depends(user_api_key_auth)],
response_model=LiteLLM_OrganizationTable,
response_model=NewOrganizationResponse,
)
async def new_organization(
data: NewOrganizationRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Allow orgs to own teams
Set org level budgets + model access.
Only admins can create orgs.
# Parameters
- `organization_alias`: *str* = The name of the organization.
- `models`: *List* = The models the organization has access to.
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ###
- `max_budget`: *Optional[float]* = Max budget for org
- `tpm_limit`: *Optional[int]* = Max tpm limit for org
- `rpm_limit`: *Optional[int]* = Max rpm limit for org
- `model_max_budget`: *Optional[dict]* = Max budget for a specific model
- `budget_duration`: *Optional[str]* = Frequency of reseting org budget
Case 1: Create new org **without** a budget_id
```bash
curl --location 'http://0.0.0.0:4000/organization/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"organization_alias": "my-secret-org",
"models": ["model1", "model2"],
"max_budget": 100
}'
```
Case 2: Create new org **with** a budget_id
```bash
curl --location 'http://0.0.0.0:4000/organization/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"organization_alias": "my-secret-org",
"models": ["model1", "model2"],
"budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689"
}'
```
"""
global prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if (
user_api_key_dict.user_role is None
or user_api_key_dict.user_role != "proxy_admin"
):
raise HTTPException(
status_code=401,
detail={
"error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}"
},
)
if data.budget_id is None:
"""
Every organization needs a budget attached.
If none provided, create one based on user max
If none provided, create one based on provided values
"""
budget_row = LiteLLM_BudgetTable(
max_budget=user_api_key_dict.max_budget,
max_parallel_requests=user_api_key_dict.max_parallel_requests,
model_max_budget=user_api_key_dict.model_max_budget,
tpm_limit=user_api_key_dict.tpm_limit,
rpm_limit=user_api_key_dict.rpm_limit,
budget_duration=user_api_key_dict.budget_duration,
budget_reset_at=user_api_key_dict.budget_reset_at,
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
budget_row = LiteLLM_BudgetTable(**data.json(exclude_none=True))
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
_budget = await prisma_client.db.litellm_budgettable.create(data={**new_budget}) # type: ignore
_budget = await prisma_client.db.litellm_budgettable.create(
data={
**new_budget, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}
) # type: ignore
data.budget_id = _budget.budget_id
"""
Ensure only models that user has access to, are given to org
"""
if len(user_api_key_dict.models) == 0: # user has access to all models
pass
else:
if len(data.models) == 0:
raise HTTPException(
status_code=400,
detail={
"error": f"User not allowed to give access to all models. Select models you want org to have access to."
},
)
for m in data.models:
if m not in user_api_key_dict.models:
raise HTTPException(
status_code=400,
detail={
"error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}"
},
)
organization_row = LiteLLM_OrganizationTable(
**data.json(exclude_none=True),
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
new_organization_row = prisma_client.jsonify_object(
organization_row.json(exclude_none=True)
)
response = await prisma_client.db.litellm_organizationtable.create(
data={
**data.json(exclude_none=True), # type: ignore
"created_by": user_api_key_dict.user_id,
"updated_by": user_api_key_dict.user_id,
**new_organization_row, # type: ignore
}
)
@ -5472,9 +5559,9 @@ async def new_organization(
"/organization/update",
tags=["organization management"],
dependencies=[Depends(user_api_key_auth)],
response_model=LiteLLM_TeamTable,
)
async def update_organization():
"""[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
pass
@ -5482,9 +5569,21 @@ async def update_organization():
"/organization/delete",
tags=["organization management"],
dependencies=[Depends(user_api_key_auth)],
response_model=LiteLLM_TeamTable,
)
async def delete_organization():
"""[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
pass
@router.post(
"/organization/info",
tags=["organization management"],
dependencies=[Depends(user_api_key_auth)],
)
async def info_organization():
"""
Get the org specific information
"""
pass

View file

@ -14,7 +14,7 @@ model LiteLLM_BudgetTable {
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
model_max_budget Json @default("{}")
model_max_budget Json?
budget_duration String?
budget_reset_at DateTime?
created_at DateTime @default(now()) @map("created_at")