From b1a105e309b210164a127429300ec5b2eec1b04f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 23 Jan 2024 12:33:13 -0800 Subject: [PATCH] feat(proxy/utils.py): enable background process to reset key budgets --- litellm/proxy/_types.py | 1 + litellm/proxy/proxy_server.py | 19 +++++ litellm/proxy/schema.prisma | 2 + litellm/proxy/utils.py | 128 +++++++++++++++++++++++++++++----- schema.prisma | 6 +- 5 files changed, 138 insertions(+), 18 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index bb56ad6bf..d5dc841cb 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -135,6 +135,7 @@ class GenerateKeyRequest(LiteLLMBase): metadata: Optional[dict] = {} tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None + budget_duration: Optional[str] = None class UpdateKeyRequest(LiteLLMBase): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 78e756a2a..398905e1a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -19,6 +19,7 @@ try: import yaml import orjson import logging + from apscheduler.schedulers.asyncio import AsyncIOScheduler except ImportError as e: raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`") @@ -73,6 +74,7 @@ from litellm.proxy.utils import ( _cache_user_row, send_email, get_logging_payload, + reset_budget, ) from litellm.proxy.secret_managers.google_kms import load_google_kms import pydantic @@ -1125,6 +1127,7 @@ async def generate_key_helper_fn( config: dict, spend: float, key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key + key_budget_duration: Optional[str] = None, max_budget: Optional[float] = None, # max_budget is used to Budget Per user token: Optional[str] = None, user_id: Optional[str] = None, @@ -1170,6 +1173,12 @@ async def generate_key_helper_fn( duration_s = _duration_in_seconds(duration=duration) expires = datetime.utcnow() + timedelta(seconds=duration_s) + if key_budget_duration is None: # one-time budget + key_reset_at = None + else: + duration_s = _duration_in_seconds(duration=key_budget_duration) + key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s) + aliases_json = json.dumps(aliases) config_json = json.dumps(config) metadata_json = json.dumps(metadata) @@ -1205,6 +1214,8 @@ async def generate_key_helper_fn( "metadata": metadata_json, "tpm_limit": tpm_limit, "rpm_limit": rpm_limit, + "budget_duration": key_budget_duration, + "budget_reset_at": key_reset_at, } if prisma_client is not None: ## CREATE USER (If necessary) @@ -1511,6 +1522,11 @@ async def startup_event(): duration=None, models=[], aliases={}, config={}, spend=0, token=master_key ) + ### START BUDGET SCHEDULER ### + scheduler = AsyncIOScheduler() + scheduler.add_job(reset_budget, "interval", seconds=10, args=[prisma_client]) + scheduler.start() + #### API ENDPOINTS #### @router.get( @@ -2186,6 +2202,9 @@ async def generate_key_fn( if "max_budget" in data_json: data_json["key_max_budget"] = data_json.pop("max_budget", None) + if "budget_duration" in data_json: + data_json["key_budget_duration"] = data_json.pop("budget_duration", None) + response = await generate_key_helper_fn(**data_json) return GenerateKeyResponse( key=response["token"], expires=response["expires"], user_id=response["user_id"] diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 931a15812..ea3bade8c 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -34,6 +34,8 @@ model LiteLLM_VerificationToken { tpm_limit BigInt? rpm_limit BigInt? max_budget Float? @default(0.0) + budget_duration String? + budget_reset_at DateTime? } model LiteLLM_Config { diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index c19137d57..109141079 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.proxy.db.base_client import CustomDB from litellm._logging import verbose_proxy_logger from fastapi import HTTPException, status -import smtplib +import smtplib, re from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart -from datetime import datetime +from datetime import datetime, timedelta def print_verbose(print_statement): @@ -363,6 +363,8 @@ class PrismaClient: user_id: Optional[str] = None, table_name: Optional[Literal["user", "key", "config"]] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", + expires: Optional[datetime] = None, + reset_at: Optional[datetime] = None, ): try: print_verbose("PrismaClient: get_data") @@ -391,6 +393,24 @@ class PrismaClient: for r in response: if isinstance(r.expires, datetime): r.expires = r.expires.isoformat() + elif ( + query_type == "find_all" + and expires is not None + and reset_at is not None + ): + response = await self.db.litellm_verificationtoken.find_many( + where={ + "OR": [ + {"expires": None}, + {"expires": {"gt": expires}}, + ], + "budget_reset_at": {"lt": reset_at}, + } + ) + if response is not None and len(response) > 0: + for r in response: + if isinstance(r.expires, datetime): + r.expires = r.expires.isoformat() print_verbose(f"PrismaClient: response={response}") if response is not None: return response @@ -517,7 +537,10 @@ class PrismaClient: self, token: Optional[str] = None, data: dict = {}, + data_list: Optional[List] = None, user_id: Optional[str] = None, + query_type: Literal["update", "update_many"] = "update", + table_name: Optional[Literal["user", "key", "config", "spend"]] = None, ): """ Update existing data @@ -526,20 +549,21 @@ class PrismaClient: db_data = self.jsonify_object(data=data) if token is not None: print_verbose(f"token: {token}") - # check if plain text or hash - if token.startswith("sk-"): - token = self.hash_token(token=token) - db_data["token"] = token - response = await self.db.litellm_verificationtoken.update( - where={"token": token}, # type: ignore - data={**db_data}, # type: ignore - ) - print_verbose( - "\033[91m" - + f"DB Token Table update succeeded {response}" - + "\033[0m" - ) - return {"token": token, "data": db_data} + if query_type == "update": + # check if plain text or hash + if token.startswith("sk-"): + token = self.hash_token(token=token) + db_data["token"] = token + response = await self.db.litellm_verificationtoken.update( + where={"token": token}, # type: ignore + data={**db_data}, # type: ignore + ) + print_verbose( + "\033[91m" + + f"DB Token Table update succeeded {response}" + + "\033[0m" + ) + return {"token": token, "data": db_data} elif user_id is not None: """ If data['spend'] + data['user'], update the user table with spend info as well @@ -566,6 +590,33 @@ class PrismaClient: + "\033[0m" ) return {"user_id": user_id, "data": db_data} + elif ( + table_name is not None + and table_name == "key" + and query_type == "update_many" + and data_list is not None + and isinstance(data_list, list) + ): + """ + Batch write update queries + """ + batcher = self.db.batch_() + for idx, t in enumerate(data_list): + # check if plain text or hash + if t.token.startswith("sk-"): # type: ignore + t.token = self.hash_token(token=t.token) # type: ignore + try: + data_json = self.jsonify_object(data=t.model_dump()) + except: + data_json = self.jsonify_object(data=t.dict()) + batcher.litellm_verificationtoken.update( + where={"token": t.token}, # type: ignore + data={**data_json}, # type: ignore + ) + await batcher.commit() + print_verbose( + "\033[91m" + f"DB Token Table update succeeded" + "\033[0m" + ) except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) @@ -886,3 +937,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time): payload[param] = str(payload[param]) return payload + + +def _duration_in_seconds(duration: str): + match = re.match(r"(\d+)([smhd]?)", duration) + if not match: + raise ValueError("Invalid duration format") + + value, unit = match.groups() + value = int(value) + + if unit == "s": + return value + elif unit == "m": + return value * 60 + elif unit == "h": + return value * 3600 + elif unit == "d": + return value * 86400 + else: + raise ValueError("Unsupported duration unit") + + +async def reset_budget(prisma_client: PrismaClient): + """ + Gets all the non-expired keys for a db, which need budget to be reset + + Resets their budget + + Updates db + """ + if prisma_client is not None: + now = datetime.utcnow() + keys_to_reset = await prisma_client.get_data( + table_name="key", query_type="find_all", expires=now, reset_at=now + ) + + for key in keys_to_reset: + key.spend = 0.0 + duration_s = _duration_in_seconds(duration=key.budget_duration) + key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s) + + if len(keys_to_reset) > 0: + await prisma_client.update_data( + query_type="update_many", data_list=keys_to_reset, table_name="key" + ) diff --git a/schema.prisma b/schema.prisma index 1212b0c66..ea3bade8c 100644 --- a/schema.prisma +++ b/schema.prisma @@ -34,6 +34,8 @@ model LiteLLM_VerificationToken { tpm_limit BigInt? rpm_limit BigInt? max_budget Float? @default(0.0) + budget_duration String? + budget_reset_at DateTime? } model LiteLLM_Config { @@ -43,8 +45,8 @@ model LiteLLM_Config { model LiteLLM_SpendLogs { request_id String @unique - api_key String @default ("") call_type String + api_key String @default ("") spend Float @default(0.0) startTime DateTime // Assuming start_time is a DateTime field endTime DateTime // Assuming end_time is a DateTime field @@ -56,4 +58,4 @@ model LiteLLM_SpendLogs { usage Json @default("{}") metadata Json @default("{}") cache_hit String @default("") -} +} \ No newline at end of file