forked from phoenix/litellm-mirror
feat(proxy/utils.py): enable background process to reset key budgets
This commit is contained in:
parent
01a2514b98
commit
b1a105e309
5 changed files with 138 additions and 18 deletions
|
@ -135,6 +135,7 @@ class GenerateKeyRequest(LiteLLMBase):
|
|||
metadata: Optional[dict] = {}
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
budget_duration: Optional[str] = None
|
||||
|
||||
|
||||
class UpdateKeyRequest(LiteLLMBase):
|
||||
|
|
|
@ -19,6 +19,7 @@ try:
|
|||
import yaml
|
||||
import orjson
|
||||
import logging
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
|
||||
|
||||
|
@ -73,6 +74,7 @@ from litellm.proxy.utils import (
|
|||
_cache_user_row,
|
||||
send_email,
|
||||
get_logging_payload,
|
||||
reset_budget,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
import pydantic
|
||||
|
@ -1125,6 +1127,7 @@ async def generate_key_helper_fn(
|
|||
config: dict,
|
||||
spend: float,
|
||||
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
|
||||
key_budget_duration: Optional[str] = None,
|
||||
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
|
||||
token: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
|
@ -1170,6 +1173,12 @@ async def generate_key_helper_fn(
|
|||
duration_s = _duration_in_seconds(duration=duration)
|
||||
expires = datetime.utcnow() + timedelta(seconds=duration_s)
|
||||
|
||||
if key_budget_duration is None: # one-time budget
|
||||
key_reset_at = None
|
||||
else:
|
||||
duration_s = _duration_in_seconds(duration=key_budget_duration)
|
||||
key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s)
|
||||
|
||||
aliases_json = json.dumps(aliases)
|
||||
config_json = json.dumps(config)
|
||||
metadata_json = json.dumps(metadata)
|
||||
|
@ -1205,6 +1214,8 @@ async def generate_key_helper_fn(
|
|||
"metadata": metadata_json,
|
||||
"tpm_limit": tpm_limit,
|
||||
"rpm_limit": rpm_limit,
|
||||
"budget_duration": key_budget_duration,
|
||||
"budget_reset_at": key_reset_at,
|
||||
}
|
||||
if prisma_client is not None:
|
||||
## CREATE USER (If necessary)
|
||||
|
@ -1511,6 +1522,11 @@ async def startup_event():
|
|||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||
)
|
||||
|
||||
### START BUDGET SCHEDULER ###
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(reset_budget, "interval", seconds=10, args=[prisma_client])
|
||||
scheduler.start()
|
||||
|
||||
|
||||
#### API ENDPOINTS ####
|
||||
@router.get(
|
||||
|
@ -2186,6 +2202,9 @@ async def generate_key_fn(
|
|||
if "max_budget" in data_json:
|
||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||
|
||||
if "budget_duration" in data_json:
|
||||
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
|
||||
|
||||
response = await generate_key_helper_fn(**data_json)
|
||||
return GenerateKeyResponse(
|
||||
key=response["token"], expires=response["expires"], user_id=response["user_id"]
|
||||
|
|
|
@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
|
|||
tpm_limit BigInt?
|
||||
rpm_limit BigInt?
|
||||
max_budget Float? @default(0.0)
|
||||
budget_duration String?
|
||||
budget_reset_at DateTime?
|
||||
}
|
||||
|
||||
model LiteLLM_Config {
|
||||
|
|
|
@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
from litellm.proxy.db.base_client import CustomDB
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from fastapi import HTTPException, status
|
||||
import smtplib
|
||||
import smtplib, re
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
|
@ -363,6 +363,8 @@ class PrismaClient:
|
|||
user_id: Optional[str] = None,
|
||||
table_name: Optional[Literal["user", "key", "config"]] = None,
|
||||
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||
expires: Optional[datetime] = None,
|
||||
reset_at: Optional[datetime] = None,
|
||||
):
|
||||
try:
|
||||
print_verbose("PrismaClient: get_data")
|
||||
|
@ -391,6 +393,24 @@ class PrismaClient:
|
|||
for r in response:
|
||||
if isinstance(r.expires, datetime):
|
||||
r.expires = r.expires.isoformat()
|
||||
elif (
|
||||
query_type == "find_all"
|
||||
and expires is not None
|
||||
and reset_at is not None
|
||||
):
|
||||
response = await self.db.litellm_verificationtoken.find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"expires": None},
|
||||
{"expires": {"gt": expires}},
|
||||
],
|
||||
"budget_reset_at": {"lt": reset_at},
|
||||
}
|
||||
)
|
||||
if response is not None and len(response) > 0:
|
||||
for r in response:
|
||||
if isinstance(r.expires, datetime):
|
||||
r.expires = r.expires.isoformat()
|
||||
print_verbose(f"PrismaClient: response={response}")
|
||||
if response is not None:
|
||||
return response
|
||||
|
@ -517,7 +537,10 @@ class PrismaClient:
|
|||
self,
|
||||
token: Optional[str] = None,
|
||||
data: dict = {},
|
||||
data_list: Optional[List] = None,
|
||||
user_id: Optional[str] = None,
|
||||
query_type: Literal["update", "update_many"] = "update",
|
||||
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
||||
):
|
||||
"""
|
||||
Update existing data
|
||||
|
@ -526,20 +549,21 @@ class PrismaClient:
|
|||
db_data = self.jsonify_object(data=data)
|
||||
if token is not None:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={"token": token}, # type: ignore
|
||||
data={**db_data}, # type: ignore
|
||||
)
|
||||
print_verbose(
|
||||
"\033[91m"
|
||||
+ f"DB Token Table update succeeded {response}"
|
||||
+ "\033[0m"
|
||||
)
|
||||
return {"token": token, "data": db_data}
|
||||
if query_type == "update":
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={"token": token}, # type: ignore
|
||||
data={**db_data}, # type: ignore
|
||||
)
|
||||
print_verbose(
|
||||
"\033[91m"
|
||||
+ f"DB Token Table update succeeded {response}"
|
||||
+ "\033[0m"
|
||||
)
|
||||
return {"token": token, "data": db_data}
|
||||
elif user_id is not None:
|
||||
"""
|
||||
If data['spend'] + data['user'], update the user table with spend info as well
|
||||
|
@ -566,6 +590,33 @@ class PrismaClient:
|
|||
+ "\033[0m"
|
||||
)
|
||||
return {"user_id": user_id, "data": db_data}
|
||||
elif (
|
||||
table_name is not None
|
||||
and table_name == "key"
|
||||
and query_type == "update_many"
|
||||
and data_list is not None
|
||||
and isinstance(data_list, list)
|
||||
):
|
||||
"""
|
||||
Batch write update queries
|
||||
"""
|
||||
batcher = self.db.batch_()
|
||||
for idx, t in enumerate(data_list):
|
||||
# check if plain text or hash
|
||||
if t.token.startswith("sk-"): # type: ignore
|
||||
t.token = self.hash_token(token=t.token) # type: ignore
|
||||
try:
|
||||
data_json = self.jsonify_object(data=t.model_dump())
|
||||
except:
|
||||
data_json = self.jsonify_object(data=t.dict())
|
||||
batcher.litellm_verificationtoken.update(
|
||||
where={"token": t.token}, # type: ignore
|
||||
data={**data_json}, # type: ignore
|
||||
)
|
||||
await batcher.commit()
|
||||
print_verbose(
|
||||
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
|
||||
)
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
|
@ -886,3 +937,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
|||
payload[param] = str(payload[param])
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
if not match:
|
||||
raise ValueError("Invalid duration format")
|
||||
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
|
||||
if unit == "s":
|
||||
return value
|
||||
elif unit == "m":
|
||||
return value * 60
|
||||
elif unit == "h":
|
||||
return value * 3600
|
||||
elif unit == "d":
|
||||
return value * 86400
|
||||
else:
|
||||
raise ValueError("Unsupported duration unit")
|
||||
|
||||
|
||||
async def reset_budget(prisma_client: PrismaClient):
    """
    Reset the spend of every key whose budget window has elapsed.

    Steps:
      1. Ask the db client for keys via
         ``get_data(table_name="key", query_type="find_all", expires=now, reset_at=now)``.
      2. For each returned key, set ``spend`` to 0.0 and push
         ``budget_reset_at`` forward by the key's ``budget_duration``.
      3. Write all modified keys back in one ``update_many`` call.

    Args:
        prisma_client: Database client used for the lookup and the batch
            update. When it is None the function does nothing.

    Returns:
        None.

    Raises:
        ValueError: If a key's ``budget_duration`` cannot be parsed by
            ``_duration_in_seconds``.
    """
    if prisma_client is None:
        return

    now = datetime.utcnow()
    keys_to_reset = await prisma_client.get_data(
        table_name="key", query_type="find_all", expires=now, reset_at=now
    )

    # Nothing to do when the lookup yields no rows. Also covers a None
    # result so the periodic job does not fail with a TypeError.
    if not keys_to_reset:
        return

    for key in keys_to_reset:
        key.spend = 0.0
        duration_s = _duration_in_seconds(duration=key.budget_duration)
        # NOTE(review): the next reset is computed from the previous reset
        # time, not from `now`, so a key that is several windows overdue is
        # reset again on subsequent runs until it catches up - confirm
        # this is intended.
        key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)

    await prisma_client.update_data(
        query_type="update_many", data_list=keys_to_reset, table_name="key"
    )
|
||||
|
|
|
@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
|
|||
tpm_limit BigInt?
|
||||
rpm_limit BigInt?
|
||||
max_budget Float? @default(0.0)
|
||||
budget_duration String?
|
||||
budget_reset_at DateTime?
|
||||
}
|
||||
|
||||
model LiteLLM_Config {
|
||||
|
@ -43,8 +45,8 @@ model LiteLLM_Config {
|
|||
|
||||
model LiteLLM_SpendLogs {
|
||||
request_id String @unique
|
||||
api_key String @default ("")
|
||||
call_type String
|
||||
api_key String @default ("")
|
||||
spend Float @default(0.0)
|
||||
startTime DateTime // Assuming start_time is a DateTime field
|
||||
endTime DateTime // Assuming end_time is a DateTime field
|
||||
|
@ -56,4 +58,4 @@ model LiteLLM_SpendLogs {
|
|||
usage Json @default("{}")
|
||||
metadata Json @default("{}")
|
||||
cache_hit String @default("")
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue