Merge branch 'main' into litellm_organization_table

This commit is contained in:
Krish Dholakia 2024-03-02 16:09:28 -08:00 committed by GitHub
commit eaccbf26b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 146 additions and 75 deletions

View file

@ -791,6 +791,7 @@ async def user_api_key_auth(
"/global/spend/keys",
"/global/spend/models",
"/global/predict/spend/logs",
"/health/services",
]
# check if the current route startswith any of the allowed routes
if (
@ -1814,6 +1815,9 @@ async def generate_key_helper_fn(
spend: float,
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
key_budget_duration: Optional[str] = None,
key_soft_budget: Optional[
float
] = None, # key_soft_budget is used to Budget Per key
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
budget_duration: Optional[str] = None, # max_budget is used to Budget Per user
token: Optional[str] = None,
@ -1873,6 +1877,19 @@ async def generate_key_helper_fn(
rpm_limit = rpm_limit
allowed_cache_controls = allowed_cache_controls
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
if prisma_client is not None:
# create the Budget Row for the LiteLLM Verification Token
budget_row = LiteLLM_BudgetTable(
soft_budget=key_soft_budget or litellm.default_soft_budget,
model_max_budget=model_max_budget or {},
created_by=user_id,
updated_by=user_id,
)
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
_budget = await prisma_client.db.litellm_budgettable.create(data={**new_budget}) # type: ignore
_budget_id = getattr(_budget, "id", None)
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
user_data = {
@ -1910,6 +1927,7 @@ async def generate_key_helper_fn(
"allowed_cache_controls": allowed_cache_controls,
"permissions": permissions_json,
"model_max_budget": model_max_budget_json,
"budget_id": _budget_id,
}
if (
general_settings.get("allow_user_auth", False) == True
@ -1982,6 +2000,9 @@ async def generate_key_helper_fn(
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
# Add budget related info in key_data - this ensures it's returned
key_data["soft_budget"] = key_soft_budget
return key_data
@ -2142,14 +2163,6 @@ async def async_data_generator(response, user_api_key_dict):
except Exception as e:
yield f"data: {str(e)}\n\n"
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
# Streaming is done, yield the [DONE] chunk
done_message = "[DONE]"
yield f"data: {done_message}\n\n"
@ -2497,14 +2510,6 @@ async def completion(
headers=custom_headers,
)
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
fastapi_response.headers["x-litellm-model-id"] = model_id
return response
except Exception as e:
@ -2703,14 +2708,6 @@ async def chat_completion(
headers=custom_headers,
)
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
fastapi_response.headers["x-litellm-model-id"] = model_id
### CALL HOOKS ### - modify outgoing data
@ -2918,12 +2915,6 @@ async def embeddings(
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
return response
except Exception as e:
@ -3069,12 +3060,6 @@ async def image_generation(
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
return response
except Exception as e:
@ -3228,12 +3213,6 @@ async def moderations(
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
return response
except Exception as e:
@ -3378,6 +3357,8 @@ async def generate_key_fn(
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None)
if "soft_budget" in data_json:
data_json["key_soft_budget"] = data_json.pop("soft_budget", None)
if "budget_duration" in data_json:
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
@ -6722,6 +6703,50 @@ async def test_endpoint(request: Request):
return {"route": request.url.path}
@router.get(
"/health/services",
tags=["health"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def health_services_endpoint(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
service: Literal["slack_budget_alerts"] = fastapi.Query(
description="Specify the service being hit."
),
):
"""
Hidden endpoint.
Used by the UI to let user check if slack alerting is working as expected.
"""
global general_settings, proxy_logging_obj
if service is None:
raise HTTPException(
status_code=400, detail={"error": "Service must be specified."}
)
if service not in ["slack_budget_alerts"]:
raise HTTPException(
status_code=400,
detail={
"error": f"Service must be in list. Service={service}. List={['slack_budget_alerts']}"
},
)
test_message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` my-secret-project \n`Expected Day of Error`: 28th March \n`Current Spend`: 100 \n`Projected Spend at end of month`: 1000 \n
"""
if "slack" in general_settings.get("alerting", []):
await proxy_logging_obj.alerting_handler(message=test_message, level="Low")
else:
raise HTTPException(
status_code=422,
detail={"error": "No slack connection setup. Unable to test this."},
)
@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
async def health_endpoint(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),