From ac085a4643291c7b920c30597ac12ebeabcd5607 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 2 Mar 2024 18:34:18 -0800 Subject: [PATCH] fix(proxy_server.py): actual implementation of slack soft budget alerting --- litellm/proxy/_types.py | 1 + litellm/proxy/proxy_server.py | 50 +++++++++++++++++-- litellm/proxy/schema.prisma | 3 ++ litellm/proxy/utils.py | 94 +++++++++++++++++++++++++++++++++-- schema.prisma | 3 ++ 5 files changed, 142 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6196f18a2..5cac7c9ad 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -509,6 +509,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): team_tpm_limit: Optional[int] = None team_rpm_limit: Optional[int] = None team_max_budget: Optional[float] = None + soft_budget: Optional[float] = None class UserAPIKeyAuth( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index eca5fb30a..b0bd20c19 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -94,6 +94,8 @@ from litellm.proxy.utils import ( _read_request_body, _is_valid_team_configs, _is_user_proxy_admin, + _is_projected_spend_over_limit, + _get_projected_spend_over_limit, ) from litellm.proxy.secret_managers.google_kms import load_google_kms import pydantic @@ -360,7 +362,6 @@ async def user_api_key_auth( valid_token = await prisma_client.get_data( token=api_key, table_name="combined_view" ) - elif custom_db_client is not None: try: valid_token = await custom_db_client.get_data( @@ -932,6 +933,7 @@ async def _PROXY_track_cost_callback( f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}" ) if user_api_key is not None: + ## UPDATE DATABASE await update_database( token=user_api_key, response_cost=response_cost, @@ -1085,6 +1087,39 @@ async def update_database( # Calculate the new cost by adding the existing cost and response_cost new_spend = existing_spend + response_cost + ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT + soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown + if existing_spend_obj.soft_budget_cooldown == False and ( + _is_projected_spend_over_limit( + current_spend=new_spend, + soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, + ) + == True + ): + key_alias = existing_spend_obj.key_alias + projected_spend, projected_exceeded_date = ( + _get_projected_spend_over_limit( + current_spend=new_spend, + soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, + ) + ) + soft_limit = existing_spend_obj.litellm_budget_table.soft_budget + user_info = { + "key_alias": key_alias, + "projected_spend": projected_spend, + "projected_exceeded_date": projected_exceeded_date, + } + # alert user + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="projected_limit_exceeded", + user_info=user_info, + user_max_budget=soft_limit, + user_current_spend=new_spend, + ) + ) + # set cooldown on alert + soft_budget_cooldown = True # track cost per model, for the given key spend_per_model = existing_spend_obj.model_spend or {} current_model = kwargs.get("model") @@ -1100,7 +1135,11 @@ async def update_database( # Update the cost column for the given token await prisma_client.update_data( token=token, - data={"spend": new_spend, "model_spend": spend_per_model}, + data={ + "spend": new_spend, + "model_spend": spend_per_model, + "soft_budget_cooldown": soft_budget_cooldown, + }, ) valid_token = user_api_key_cache.get_cache(key=token) @@ -1874,17 +1913,18 @@ async def generate_key_helper_fn( allowed_cache_controls = allowed_cache_controls # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable - if prisma_client is not None: + _budget_id = None + if prisma_client is not None and key_soft_budget is not None: # create the Budget Row for the LiteLLM Verification Token budget_row = LiteLLM_BudgetTable( - soft_budget=key_soft_budget or litellm.default_soft_budget, + soft_budget=key_soft_budget, model_max_budget=model_max_budget or {}, created_by=user_id, updated_by=user_id, ) new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) _budget = await prisma_client.db.litellm_budgettable.create(data={**new_budget}) # type: ignore - _budget_id = getattr(_budget, "id", None) + _budget_id = getattr(_budget, "budget_id", None) try: # Create a new verification token (you may want to enhance this logic based on your needs) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 1fe55f24e..54e23e769 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -23,6 +23,7 @@ model LiteLLM_BudgetTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") updated_by String organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget + keys LiteLLM_VerificationToken[] // multiple keys can have the same budget } model LiteLLM_OrganizationTable { @@ -90,6 +91,7 @@ model LiteLLM_VerificationToken { token String @id key_name String? key_alias String? + soft_budget_cooldown Boolean @default(false) // key-level state on if budget alerts need to be cooled down spend Float @default(0.0) expires DateTime? models String[] @@ -109,6 +111,7 @@ model LiteLLM_VerificationToken { model_spend Json @default("{}") model_max_budget Json @default("{}") budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) } // store proxy config.yaml diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index c67448c86..31e913e8b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -222,6 +222,7 @@ class ProxyLogging: "user_and_proxy_budget", "failed_budgets", "failed_tracking", + "projected_limit_exceeded", ], user_max_budget: float, user_current_spend: float, @@ -255,6 +256,23 @@ class ProxyLogging: level="High", ) return + elif type == "projected_limit_exceeded" and user_info is not None: + """ + Input variables: + user_info = { + "key_alias": key_alias, + "projected_spend": projected_spend, + "projected_exceeded_date": projected_exceeded_date, + } + user_max_budget=soft_limit, + user_current_spend=new_spend + """ + message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}""" + await self.alerting_handler( + message=message, + level="High", + ) + return else: user_info = str(user_info) # percent of max_budget left to spend @@ -748,7 +766,8 @@ class PrismaClient: detail={"error": f"No token passed in. Token={token}"}, ) response = await self.db.litellm_verificationtoken.find_unique( - where={"token": hashed_token} + where={"token": hashed_token}, + include={"litellm_budget_table": True}, ) if response is not None: # for prisma we need to cast the expires time to str @@ -758,7 +777,8 @@ class PrismaClient: response.expires = response.expires.isoformat() elif query_type == "find_all" and user_id is not None: response = await self.db.litellm_verificationtoken.find_many( - where={"user_id": user_id} + where={"user_id": user_id}, + include={"litellm_budget_table": True}, ) if response is not None and len(response) > 0: for r in response: @@ -766,7 +786,8 @@ class PrismaClient: r.expires = r.expires.isoformat() elif query_type == "find_all" and team_id is not None: response = await self.db.litellm_verificationtoken.find_many( - where={"team_id": team_id} + where={"team_id": team_id}, + include={"litellm_budget_table": True}, ) if response is not None and len(response) > 0: for r in response: @@ -809,7 +830,9 @@ class PrismaClient: hashed_tokens.append(t) where_filter["token"]["in"] = hashed_tokens response = await self.db.litellm_verificationtoken.find_many( - order={"spend": "desc"}, where=where_filter # type: ignore + order={"spend": "desc"}, + where=where_filter, # type: ignore + include={"litellm_budget_table": True}, ) if response is not None: return response @@ -1728,6 +1751,69 @@ async def _read_request_body(request): return {} +def _is_projected_spend_over_limit( + current_spend: float, soft_budget_limit: Optional[float] +): + from datetime import date + + if soft_budget_limit is None: + # If there's no limit, we can't exceed it. + return False + + today = date.today() + + # Finding the first day of the next month, then subtracting one day to get the end of the current month. + if today.month == 12: # December edge case + end_month = date(today.year + 1, 1, 1) - timedelta(days=1) + else: + end_month = date(today.year, today.month + 1, 1) - timedelta(days=1) + + remaining_days = (end_month - today).days + + # Check for the start of the month to avoid division by zero + if today.day == 1: + daily_spend_estimate = current_spend + else: + daily_spend_estimate = current_spend / (today.day - 1) + + # Total projected spend for the month + projected_spend = current_spend + (daily_spend_estimate * remaining_days) + + if projected_spend > soft_budget_limit: + print_verbose("Projected spend exceeds soft budget limit!") + return True + return False + + +def _get_projected_spend_over_limit( + current_spend: float, soft_budget_limit: Optional[float] +) -> Optional[tuple]: + import datetime + + if soft_budget_limit is None: + return None + + today = datetime.date.today() + end_month = datetime.date(today.year, today.month + 1, 1) - datetime.timedelta( + days=1 + ) + remaining_days = (end_month - today).days + + daily_spend = current_spend / ( + today.day - 1 + ) # assuming the current spend till today (not including today) + projected_spend = daily_spend * remaining_days + + if projected_spend > soft_budget_limit: + approx_days = soft_budget_limit / daily_spend + limit_exceed_date = today + datetime.timedelta(days=approx_days) + + # return the projected spend and the date it will exceeded + return projected_spend, limit_exceed_date + + return None + + def _is_valid_team_configs(team_id=None, team_config=None, request_data=None): if team_id is None or team_config is None or request_data is None: return diff --git a/schema.prisma b/schema.prisma index 1fe55f24e..54e23e769 100644 --- a/schema.prisma +++ b/schema.prisma @@ -23,6 +23,7 @@ model LiteLLM_BudgetTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") updated_by String organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget + keys LiteLLM_VerificationToken[] // multiple keys can have the same budget } model LiteLLM_OrganizationTable { @@ -90,6 +91,7 @@ model LiteLLM_VerificationToken { token String @id key_name String? key_alias String? + soft_budget_cooldown Boolean @default(false) // key-level state on if budget alerts need to be cooled down spend Float @default(0.0) expires DateTime? models String[] @@ -109,6 +111,7 @@ model LiteLLM_VerificationToken { model_spend Json @default("{}") model_max_budget Json @default("{}") budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) } // store proxy config.yaml