feat(proxy/utils.py): enable background process to reset key budgets

This commit is contained in:
Krrish Dholakia 2024-01-23 12:33:13 -08:00
parent 01a2514b98
commit b1a105e309
5 changed files with 138 additions and 18 deletions

View file

@ -135,6 +135,7 @@ class GenerateKeyRequest(LiteLLMBase):
metadata: Optional[dict] = {} metadata: Optional[dict] = {}
tpm_limit: Optional[int] = None tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None
class UpdateKeyRequest(LiteLLMBase): class UpdateKeyRequest(LiteLLMBase):

View file

@ -19,6 +19,7 @@ try:
import yaml import yaml
import orjson import orjson
import logging import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
except ImportError as e: except ImportError as e:
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`") raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
@ -73,6 +74,7 @@ from litellm.proxy.utils import (
_cache_user_row, _cache_user_row,
send_email, send_email,
get_logging_payload, get_logging_payload,
reset_budget,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic import pydantic
@ -1125,6 +1127,7 @@ async def generate_key_helper_fn(
config: dict, config: dict,
spend: float, spend: float,
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
key_budget_duration: Optional[str] = None,
max_budget: Optional[float] = None, # max_budget is used to Budget Per user max_budget: Optional[float] = None, # max_budget is used to Budget Per user
token: Optional[str] = None, token: Optional[str] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
@ -1170,6 +1173,12 @@ async def generate_key_helper_fn(
duration_s = _duration_in_seconds(duration=duration) duration_s = _duration_in_seconds(duration=duration)
expires = datetime.utcnow() + timedelta(seconds=duration_s) expires = datetime.utcnow() + timedelta(seconds=duration_s)
if key_budget_duration is None: # one-time budget
key_reset_at = None
else:
duration_s = _duration_in_seconds(duration=key_budget_duration)
key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s)
aliases_json = json.dumps(aliases) aliases_json = json.dumps(aliases)
config_json = json.dumps(config) config_json = json.dumps(config)
metadata_json = json.dumps(metadata) metadata_json = json.dumps(metadata)
@ -1205,6 +1214,8 @@ async def generate_key_helper_fn(
"metadata": metadata_json, "metadata": metadata_json,
"tpm_limit": tpm_limit, "tpm_limit": tpm_limit,
"rpm_limit": rpm_limit, "rpm_limit": rpm_limit,
"budget_duration": key_budget_duration,
"budget_reset_at": key_reset_at,
} }
if prisma_client is not None: if prisma_client is not None:
## CREATE USER (If necessary) ## CREATE USER (If necessary)
@ -1511,6 +1522,11 @@ async def startup_event():
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
) )
### START BUDGET SCHEDULER ###
scheduler = AsyncIOScheduler()
scheduler.add_job(reset_budget, "interval", seconds=10, args=[prisma_client])
scheduler.start()
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get( @router.get(
@ -2186,6 +2202,9 @@ async def generate_key_fn(
if "max_budget" in data_json: if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None) data_json["key_max_budget"] = data_json.pop("max_budget", None)
if "budget_duration" in data_json:
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
response = await generate_key_helper_fn(**data_json) response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse( return GenerateKeyResponse(
key=response["token"], expires=response["expires"], user_id=response["user_id"] key=response["token"], expires=response["expires"], user_id=response["user_id"]

View file

@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
tpm_limit BigInt? tpm_limit BigInt?
rpm_limit BigInt? rpm_limit BigInt?
max_budget Float? @default(0.0) max_budget Float? @default(0.0)
budget_duration String?
budget_reset_at DateTime?
} }
model LiteLLM_Config { model LiteLLM_Config {

View file

@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status from fastapi import HTTPException, status
import smtplib import smtplib, re
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from datetime import datetime from datetime import datetime, timedelta
def print_verbose(print_statement): def print_verbose(print_statement):
@ -363,6 +363,8 @@ class PrismaClient:
user_id: Optional[str] = None, user_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config"]] = None, table_name: Optional[Literal["user", "key", "config"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique", query_type: Literal["find_unique", "find_all"] = "find_unique",
expires: Optional[datetime] = None,
reset_at: Optional[datetime] = None,
): ):
try: try:
print_verbose("PrismaClient: get_data") print_verbose("PrismaClient: get_data")
@ -391,6 +393,24 @@ class PrismaClient:
for r in response: for r in response:
if isinstance(r.expires, datetime): if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat() r.expires = r.expires.isoformat()
elif (
query_type == "find_all"
and expires is not None
and reset_at is not None
):
response = await self.db.litellm_verificationtoken.find_many(
where={
"OR": [
{"expires": None},
{"expires": {"gt": expires}},
],
"budget_reset_at": {"lt": reset_at},
}
)
if response is not None and len(response) > 0:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
print_verbose(f"PrismaClient: response={response}") print_verbose(f"PrismaClient: response={response}")
if response is not None: if response is not None:
return response return response
@ -517,7 +537,10 @@ class PrismaClient:
self, self,
token: Optional[str] = None, token: Optional[str] = None,
data: dict = {}, data: dict = {},
data_list: Optional[List] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
query_type: Literal["update", "update_many"] = "update",
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
): ):
""" """
Update existing data Update existing data
@ -526,20 +549,21 @@ class PrismaClient:
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
if token is not None: if token is not None:
print_verbose(f"token: {token}") print_verbose(f"token: {token}")
# check if plain text or hash if query_type == "update":
if token.startswith("sk-"): # check if plain text or hash
token = self.hash_token(token=token) if token.startswith("sk-"):
db_data["token"] = token token = self.hash_token(token=token)
response = await self.db.litellm_verificationtoken.update( db_data["token"] = token
where={"token": token}, # type: ignore response = await self.db.litellm_verificationtoken.update(
data={**db_data}, # type: ignore where={"token": token}, # type: ignore
) data={**db_data}, # type: ignore
print_verbose( )
"\033[91m" print_verbose(
+ f"DB Token Table update succeeded {response}" "\033[91m"
+ "\033[0m" + f"DB Token Table update succeeded {response}"
) + "\033[0m"
return {"token": token, "data": db_data} )
return {"token": token, "data": db_data}
elif user_id is not None: elif user_id is not None:
""" """
If data['spend'] + data['user'], update the user table with spend info as well If data['spend'] + data['user'], update the user table with spend info as well
@ -566,6 +590,33 @@ class PrismaClient:
+ "\033[0m" + "\033[0m"
) )
return {"user_id": user_id, "data": db_data} return {"user_id": user_id, "data": db_data}
elif (
table_name is not None
and table_name == "key"
and query_type == "update_many"
and data_list is not None
and isinstance(data_list, list)
):
"""
Batch write update queries
"""
batcher = self.db.batch_()
for idx, t in enumerate(data_list):
# check if plain text or hash
if t.token.startswith("sk-"): # type: ignore
t.token = self.hash_token(token=t.token) # type: ignore
try:
data_json = self.jsonify_object(data=t.model_dump())
except:
data_json = self.jsonify_object(data=t.dict())
batcher.litellm_verificationtoken.update(
where={"token": t.token}, # type: ignore
data={**data_json}, # type: ignore
)
await batcher.commit()
print_verbose(
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
)
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)
@ -886,3 +937,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
payload[param] = str(payload[param]) payload[param] = str(payload[param])
return payload return payload
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
else:
raise ValueError("Unsupported duration unit")
async def reset_budget(prisma_client: PrismaClient):
    """
    Reset the spend of every key whose budget window has elapsed.

    Steps:
    1. Ask the db (via `prisma_client.get_data`) for keys, passing the current
       UTC time as both `expires` and `reset_at`.
    2. For each returned key, set `spend` back to 0.0 and push
       `budget_reset_at` forward by the key's `budget_duration`.
    3. Write the modified keys back in a single `update_many` call.

    A key whose `budget_duration` is missing or cannot be parsed is skipped
    (and logged) so that one bad row does not prevent the remaining keys from
    being reset.

    Does nothing when `prisma_client` is None or no keys are returned.
    """
    if prisma_client is None:
        return

    now = datetime.utcnow()
    keys_to_reset = await prisma_client.get_data(
        table_name="key", query_type="find_all", expires=now, reset_at=now
    )
    if not keys_to_reset:  # None or empty list -> nothing to do
        return

    updated_keys = []
    for key in keys_to_reset:
        try:
            duration_s = _duration_in_seconds(duration=key.budget_duration)
        except (TypeError, ValueError) as e:
            # TypeError: budget_duration is None; ValueError: malformed string
            print_verbose(
                f"reset_budget: skipping key with invalid budget_duration={key.budget_duration}: {e}"
            )
            continue
        key.spend = 0.0
        # NOTE(review): advances from the previous reset time, not from `now`;
        # a long-stale key is only moved forward one window per run - confirm intended
        key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)
        updated_keys.append(key)

    if updated_keys:
        await prisma_client.update_data(
            query_type="update_many", data_list=updated_keys, table_name="key"
        )

View file

@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
tpm_limit BigInt? tpm_limit BigInt?
rpm_limit BigInt? rpm_limit BigInt?
max_budget Float? @default(0.0) max_budget Float? @default(0.0)
budget_duration String?
budget_reset_at DateTime?
} }
model LiteLLM_Config { model LiteLLM_Config {
@ -43,8 +45,8 @@ model LiteLLM_Config {
model LiteLLM_SpendLogs { model LiteLLM_SpendLogs {
request_id String @unique request_id String @unique
api_key String @default ("")
call_type String call_type String
api_key String @default ("")
spend Float @default(0.0) spend Float @default(0.0)
startTime DateTime // Assuming start_time is a DateTime field startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field endTime DateTime // Assuming end_time is a DateTime field
@ -56,4 +58,4 @@ model LiteLLM_SpendLogs {
usage Json @default("{}") usage Json @default("{}")
metadata Json @default("{}") metadata Json @default("{}")
cache_hit String @default("") cache_hit String @default("")
} }