feat(proxy/utils.py): enable background process to reset key budgets

This commit is contained in:
Krrish Dholakia 2024-01-23 12:33:13 -08:00
parent 01a2514b98
commit b1a105e309
5 changed files with 138 additions and 18 deletions

View file

@ -135,6 +135,7 @@ class GenerateKeyRequest(LiteLLMBase):
metadata: Optional[dict] = {}
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None
class UpdateKeyRequest(LiteLLMBase):

View file

@ -19,6 +19,7 @@ try:
import yaml
import orjson
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
except ImportError as e:
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
@ -73,6 +74,7 @@ from litellm.proxy.utils import (
_cache_user_row,
send_email,
get_logging_payload,
reset_budget,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic
@ -1125,6 +1127,7 @@ async def generate_key_helper_fn(
config: dict,
spend: float,
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
key_budget_duration: Optional[str] = None,
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
token: Optional[str] = None,
user_id: Optional[str] = None,
@ -1170,6 +1173,12 @@ async def generate_key_helper_fn(
duration_s = _duration_in_seconds(duration=duration)
expires = datetime.utcnow() + timedelta(seconds=duration_s)
if key_budget_duration is None: # one-time budget
key_reset_at = None
else:
duration_s = _duration_in_seconds(duration=key_budget_duration)
key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s)
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
metadata_json = json.dumps(metadata)
@ -1205,6 +1214,8 @@ async def generate_key_helper_fn(
"metadata": metadata_json,
"tpm_limit": tpm_limit,
"rpm_limit": rpm_limit,
"budget_duration": key_budget_duration,
"budget_reset_at": key_reset_at,
}
if prisma_client is not None:
## CREATE USER (If necessary)
@ -1511,6 +1522,11 @@ async def startup_event():
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
### START BUDGET SCHEDULER ###
scheduler = AsyncIOScheduler()
scheduler.add_job(reset_budget, "interval", seconds=10, args=[prisma_client])
scheduler.start()
#### API ENDPOINTS ####
@router.get(
@ -2186,6 +2202,9 @@ async def generate_key_fn(
if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None)
if "budget_duration" in data_json:
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(
key=response["token"], expires=response["expires"], user_id=response["user_id"]

View file

@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
tpm_limit BigInt?
rpm_limit BigInt?
max_budget Float? @default(0.0)
budget_duration String?
budget_reset_at DateTime?
}
model LiteLLM_Config {

View file

@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status
import smtplib
import smtplib, re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime
from datetime import datetime, timedelta
def print_verbose(print_statement):
@ -363,6 +363,8 @@ class PrismaClient:
user_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique",
expires: Optional[datetime] = None,
reset_at: Optional[datetime] = None,
):
try:
print_verbose("PrismaClient: get_data")
@ -391,6 +393,24 @@ class PrismaClient:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
elif (
query_type == "find_all"
and expires is not None
and reset_at is not None
):
response = await self.db.litellm_verificationtoken.find_many(
where={
"OR": [
{"expires": None},
{"expires": {"gt": expires}},
],
"budget_reset_at": {"lt": reset_at},
}
)
if response is not None and len(response) > 0:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
print_verbose(f"PrismaClient: response={response}")
if response is not None:
return response
@ -517,7 +537,10 @@ class PrismaClient:
self,
token: Optional[str] = None,
data: dict = {},
data_list: Optional[List] = None,
user_id: Optional[str] = None,
query_type: Literal["update", "update_many"] = "update",
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
):
"""
Update existing data
@ -526,20 +549,21 @@ class PrismaClient:
db_data = self.jsonify_object(data=data)
if token is not None:
print_verbose(f"token: {token}")
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose(
"\033[91m"
+ f"DB Token Table update succeeded {response}"
+ "\033[0m"
)
return {"token": token, "data": db_data}
if query_type == "update":
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose(
"\033[91m"
+ f"DB Token Table update succeeded {response}"
+ "\033[0m"
)
return {"token": token, "data": db_data}
elif user_id is not None:
"""
If data['spend'] + data['user'], update the user table with spend info as well
@ -566,6 +590,33 @@ class PrismaClient:
+ "\033[0m"
)
return {"user_id": user_id, "data": db_data}
elif (
table_name is not None
and table_name == "key"
and query_type == "update_many"
and data_list is not None
and isinstance(data_list, list)
):
"""
Batch write update queries
"""
batcher = self.db.batch_()
for idx, t in enumerate(data_list):
# check if plain text or hash
if t.token.startswith("sk-"): # type: ignore
t.token = self.hash_token(token=t.token) # type: ignore
try:
data_json = self.jsonify_object(data=t.model_dump())
except:
data_json = self.jsonify_object(data=t.dict())
batcher.litellm_verificationtoken.update(
where={"token": t.token}, # type: ignore
data={**data_json}, # type: ignore
)
await batcher.commit()
print_verbose(
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
)
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
@ -886,3 +937,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
payload[param] = str(payload[param])
return payload
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
else:
raise ValueError("Unsupported duration unit")
async def reset_budget(prisma_client: PrismaClient):
    """
    Gets all the non-expired keys for a db, which need budget to be reset
    Resets their budget
    Updates db

    Parameters:
    - prisma_client: PrismaClient - db client used to query and update keys.
      If None, the function does nothing.

    For every key returned, `spend` is set to 0.0 and `budget_reset_at` is
    moved one `budget_duration` past the current time; the keys are then
    written back in a single `update_many` call.

    Raises:
    - ValueError - if a key's `budget_duration` cannot be parsed by
      `_duration_in_seconds`.
    """
    if prisma_client is None:
        return

    now = datetime.utcnow()
    keys_to_reset = await prisma_client.get_data(
        table_name="key", query_type="find_all", expires=now, reset_at=now
    )
    # Covers both "no keys are due" and a None result from get_data.
    if not keys_to_reset:
        return

    for key in keys_to_reset:
        key.spend = 0.0
        duration_s = _duration_in_seconds(duration=key.budget_duration)
        # Schedule from `now`, not from the previous reset time: a key that is
        # several periods overdue (e.g. the proxy was down) would otherwise
        # stay in the past and have its spend zeroed again on every scheduler
        # run until it caught up.
        key.budget_reset_at = now + timedelta(seconds=duration_s)

    await prisma_client.update_data(
        query_type="update_many", data_list=keys_to_reset, table_name="key"
    )

View file

@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
tpm_limit BigInt?
rpm_limit BigInt?
max_budget Float? @default(0.0)
budget_duration String?
budget_reset_at DateTime?
}
model LiteLLM_Config {
@ -43,8 +45,8 @@ model LiteLLM_Config {
model LiteLLM_SpendLogs {
request_id String @unique
api_key String @default ("")
call_type String
api_key String @default ("")
spend Float @default(0.0)
startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field
@ -56,4 +58,4 @@ model LiteLLM_SpendLogs {
usage Json @default("{}")
metadata Json @default("{}")
cache_hit String @default("")
}
}