Merge branch 'main' into litellm_slack_budget_alerting

This commit is contained in:
Krish Dholakia 2024-03-02 19:13:57 -08:00 committed by GitHub
commit cbeb65a442
35 changed files with 520 additions and 67 deletions

View file

@ -241,6 +241,7 @@ health_check_interval = None
health_check_results = {}
queue: List = []
litellm_proxy_budget_name = "litellm-proxy-budget"
litellm_proxy_admin_name = "default_user_id"
ui_access_mode: Literal["admin", "all"] = "all"
proxy_budget_rescheduler_min_time = 597
proxy_budget_rescheduler_max_time = 605
@ -337,7 +338,11 @@ async def user_api_key_auth(
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key)
if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key, user_role="proxy_admin")
return UserAPIKeyAuth(
api_key=master_key,
user_role="proxy_admin",
user_id=litellm_proxy_admin_name,
)
if isinstance(
api_key, str
): # if generated token, make sure it starts with sk-.
@ -1853,6 +1858,9 @@ async def generate_key_helper_fn(
key_soft_budget: Optional[
float
] = None, # key_soft_budget is used to Budget Per key
soft_budget: Optional[
float
] = None, # soft_budget is used to set soft Budgets Per user
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
budget_duration: Optional[str] = None, # max_budget is used to Budget Per user
token: Optional[str] = None,
@ -1919,13 +1927,18 @@ async def generate_key_helper_fn(
budget_row = LiteLLM_BudgetTable(
soft_budget=key_soft_budget,
model_max_budget=model_max_budget or {},
created_by=user_id,
updated_by=user_id,
)
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
_budget = await prisma_client.db.litellm_budgettable.create(data={**new_budget}) # type: ignore
_budget_id = getattr(_budget, "budget_id", None)
_budget = await prisma_client.db.litellm_budgettable.create(
data={
**new_budget, # type: ignore
"created_by": user_id,
"updated_by": user_id,
}
)
_budget_id = getattr(_budget, "budget_id", None)
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
user_data = {
@ -2266,7 +2279,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
import json
### LOAD MASTER KEY ###
@ -2313,9 +2326,8 @@ async def startup_event():
if prisma_client is not None and master_key is not None:
# add master key to db
user_id = "default_user_id"
if os.getenv("PROXY_ADMIN_ID", None) is not None:
user_id = os.getenv("PROXY_ADMIN_ID")
litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
asyncio.create_task(
generate_key_helper_fn(
@ -2325,7 +2337,7 @@ async def startup_event():
config={},
spend=0,
token=master_key,
user_id=user_id,
user_id=litellm_proxy_admin_name,
user_role="proxy_admin",
query_type="update_data",
update_key_values={
@ -5433,6 +5445,226 @@ async def team_info(
)
#### ORGANIZATION MANAGEMENT ####
@router.post(
    "/organization/new",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewOrganizationResponse,
)
async def new_organization(
    data: NewOrganizationRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow orgs to own teams

    Set org level budgets + model access.

    Only admins can create orgs.

    # Parameters

    - `organization_alias`: *str* = The name of the organization.
    - `models`: *List* = The models the organization has access to.
    - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.

    ### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ###

    - `max_budget`: *Optional[float]* = Max budget for org
    - `tpm_limit`: *Optional[int]* = Max tpm limit for org
    - `rpm_limit`: *Optional[int]* = Max rpm limit for org
    - `model_max_budget`: *Optional[dict]* = Max budget for a specific model
    - `budget_duration`: *Optional[str]* = Frequency of resetting org budget

    # Errors

    - `500`: no database is connected.
    - `401`: the caller's role is not `proxy_admin`.
    - `400`: the caller has a restricted model list and the request either
      names no models or names a model outside that list.

    Case 1: Create new org **without** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "max_budget": 100
    }'
    ```

    Case 2: Create new org **with** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689"
    }'
    ```
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # A missing role (None) is also rejected by this comparison.
    if user_api_key_dict.user_role != "proxy_admin":
        raise HTTPException(
            status_code=401,
            detail={
                "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}"
            },
        )

    # Ensure only models that the caller has access to are given to the org.
    # This runs BEFORE any database write so that a rejected request does not
    # leave an orphaned budget row behind.
    if len(user_api_key_dict.models) > 0:  # empty list = access to all models
        if len(data.models) == 0:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"User not allowed to give access to all models. Select models you want org to have access to."
                },
            )
        for m in data.models:
            if m not in user_api_key_dict.models:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}"
                    },
                )

    # Identity recorded in created_by / updated_by; falls back to the proxy
    # admin name when the key carries no user_id.
    actor_id = user_api_key_dict.user_id or litellm_proxy_admin_name

    if data.budget_id is None:
        # Every organization needs a budget attached.
        # If none provided, create one based on provided values.
        budget_row = LiteLLM_BudgetTable(**data.json(exclude_none=True))
        new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
        _budget = await prisma_client.db.litellm_budgettable.create(
            data={
                **new_budget,  # type: ignore
                "created_by": actor_id,
                "updated_by": actor_id,
            }
        )  # type: ignore
        data.budget_id = _budget.budget_id

    organization_row = LiteLLM_OrganizationTable(
        **data.json(exclude_none=True),
        created_by=actor_id,
        updated_by=actor_id,
    )
    new_organization_row = prisma_client.jsonify_object(
        organization_row.json(exclude_none=True)
    )
    response = await prisma_client.db.litellm_organizationtable.create(
        data={
            **new_organization_row,  # type: ignore
        }
    )

    return response
@router.post(
    "/organization/update",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_organization():
    """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
    # Placeholder endpoint: performs no work and yields no payload.
    return None
@router.post(
    "/organization/delete",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_organization():
    """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
    # Placeholder endpoint: performs no work and yields no payload.
    return None
@router.post(
    "/organization/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_organization(data: OrganizationRequest):
    """
    Get the org specific information

    Looks up every organization whose id appears in `data.organizations`
    and returns the matching rows with their budget table included.

    Responds with 500 when no database is connected and with 400 when the
    list of organization ids is empty.
    """
    # Guard 1: a database client is required for the lookup below.
    if prisma_client is None:
        no_db_detail = {"error": "No db connected"}
        raise HTTPException(status_code=500, detail=no_db_detail)

    # Guard 2: refuse an empty id list instead of running a no-op query.
    if len(data.organizations) == 0:
        empty_list_detail = {
            "error": f"Specify list of organization id's to query. Passed in={data.organizations}"
        }
        raise HTTPException(status_code=400, detail=empty_list_detail)

    id_filter = {"organization_id": {"in": data.organizations}}
    organization_rows = await prisma_client.db.litellm_organizationtable.find_many(
        where=id_filter,
        include={"litellm_budget_table": True},
    )
    return organization_rows
#### BUDGET TABLE MANAGEMENT ####
@router.post(
    "/budget/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
    """
    Get the budget id specific information

    Looks up every budget whose id appears in `data.budgets` and returns
    the matching rows.

    Responds with 500 when no database is connected and with 400 when the
    list of budget ids is empty.
    """
    # Guard 1: a database client is required for the lookup below.
    if prisma_client is None:
        no_db_detail = {"error": "No db connected"}
        raise HTTPException(status_code=500, detail=no_db_detail)

    # Guard 2: refuse an empty id list instead of running a no-op query.
    if len(data.budgets) == 0:
        empty_list_detail = {
            "error": f"Specify list of budget id's to query. Passed in={data.budgets}"
        }
        raise HTTPException(status_code=400, detail=empty_list_detail)

    id_filter = {"budget_id": {"in": data.budgets}}
    budget_rows = await prisma_client.db.litellm_budgettable.find_many(
        where=id_filter,
    )
    return budget_rows
#### MODEL MANAGEMENT ####
@ -6537,30 +6769,47 @@ async def health_services_endpoint(
Used by the UI to let user check if slack alerting is working as expected.
"""
global general_settings, proxy_logging_obj
try:
global general_settings, proxy_logging_obj
if service is None:
raise HTTPException(
status_code=400, detail={"error": "Service must be specified."}
)
if service is None:
raise HTTPException(
status_code=400, detail={"error": "Service must be specified."}
)
if service not in ["slack_budget_alerts"]:
raise HTTPException(
status_code=400,
detail={
"error": f"Service must be in list. Service={service}. List={['slack_budget_alerts']}"
},
)
if service not in ["slack_budget_alerts"]:
raise HTTPException(
status_code=400,
detail={
"error": f"Service must be in list. Service={service}. List={['slack_budget_alerts']}"
},
)
test_message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` my-secret-project \n`Expected Day of Error`: 28th March \n`Current Spend`: 100 \n`Projected Spend at end of month`: 1000 \n
"""
if "slack" in general_settings.get("alerting", []):
await proxy_logging_obj.alerting_handler(message=test_message, level="Low")
else:
raise HTTPException(
status_code=422,
detail={"error": "No slack connection setup. Unable to test this."},
if "slack" in general_settings.get("alerting", []):
test_message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` litellm-ui-test-alert \n`Expected Day of Error`: 28th March \n`Current Spend`: $100.00 \n`Projected Spend at end of month`: $1000.00 \n`Soft Limit`: $700"""
await proxy_logging_obj.alerting_handler(message=test_message, level="Low")
else:
raise HTTPException(
status_code=422,
detail={
"error": '"slack" not in proxy config: general_settings. Unable to test this.'
},
)
except Exception as e:
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_401_UNAUTHORIZED,
)