forked from phoenix/litellm-mirror
feat(proxy/utils.py): enable background process to reset key budgets
This commit is contained in:
parent
01a2514b98
commit
b1a105e309
5 changed files with 138 additions and 18 deletions
|
@ -135,6 +135,7 @@ class GenerateKeyRequest(LiteLLMBase):
|
||||||
metadata: Optional[dict] = {}
|
metadata: Optional[dict] = {}
|
||||||
tpm_limit: Optional[int] = None
|
tpm_limit: Optional[int] = None
|
||||||
rpm_limit: Optional[int] = None
|
rpm_limit: Optional[int] = None
|
||||||
|
budget_duration: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UpdateKeyRequest(LiteLLMBase):
|
class UpdateKeyRequest(LiteLLMBase):
|
||||||
|
|
|
@ -19,6 +19,7 @@ try:
|
||||||
import yaml
|
import yaml
|
||||||
import orjson
|
import orjson
|
||||||
import logging
|
import logging
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
|
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
|
||||||
|
|
||||||
|
@ -73,6 +74,7 @@ from litellm.proxy.utils import (
|
||||||
_cache_user_row,
|
_cache_user_row,
|
||||||
send_email,
|
send_email,
|
||||||
get_logging_payload,
|
get_logging_payload,
|
||||||
|
reset_budget,
|
||||||
)
|
)
|
||||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||||
import pydantic
|
import pydantic
|
||||||
|
@ -1125,6 +1127,7 @@ async def generate_key_helper_fn(
|
||||||
config: dict,
|
config: dict,
|
||||||
spend: float,
|
spend: float,
|
||||||
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
|
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
|
||||||
|
key_budget_duration: Optional[str] = None,
|
||||||
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
|
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
|
@ -1170,6 +1173,12 @@ async def generate_key_helper_fn(
|
||||||
duration_s = _duration_in_seconds(duration=duration)
|
duration_s = _duration_in_seconds(duration=duration)
|
||||||
expires = datetime.utcnow() + timedelta(seconds=duration_s)
|
expires = datetime.utcnow() + timedelta(seconds=duration_s)
|
||||||
|
|
||||||
|
if key_budget_duration is None: # one-time budget
|
||||||
|
key_reset_at = None
|
||||||
|
else:
|
||||||
|
duration_s = _duration_in_seconds(duration=key_budget_duration)
|
||||||
|
key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s)
|
||||||
|
|
||||||
aliases_json = json.dumps(aliases)
|
aliases_json = json.dumps(aliases)
|
||||||
config_json = json.dumps(config)
|
config_json = json.dumps(config)
|
||||||
metadata_json = json.dumps(metadata)
|
metadata_json = json.dumps(metadata)
|
||||||
|
@ -1205,6 +1214,8 @@ async def generate_key_helper_fn(
|
||||||
"metadata": metadata_json,
|
"metadata": metadata_json,
|
||||||
"tpm_limit": tpm_limit,
|
"tpm_limit": tpm_limit,
|
||||||
"rpm_limit": rpm_limit,
|
"rpm_limit": rpm_limit,
|
||||||
|
"budget_duration": key_budget_duration,
|
||||||
|
"budget_reset_at": key_reset_at,
|
||||||
}
|
}
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
## CREATE USER (If necessary)
|
## CREATE USER (If necessary)
|
||||||
|
@ -1511,6 +1522,11 @@ async def startup_event():
|
||||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### START BUDGET SCHEDULER ###
|
||||||
|
scheduler = AsyncIOScheduler()
|
||||||
|
scheduler.add_job(reset_budget, "interval", seconds=10, args=[prisma_client])
|
||||||
|
scheduler.start()
|
||||||
|
|
||||||
|
|
||||||
#### API ENDPOINTS ####
|
#### API ENDPOINTS ####
|
||||||
@router.get(
|
@router.get(
|
||||||
|
@ -2186,6 +2202,9 @@ async def generate_key_fn(
|
||||||
if "max_budget" in data_json:
|
if "max_budget" in data_json:
|
||||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||||
|
|
||||||
|
if "budget_duration" in data_json:
|
||||||
|
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
|
||||||
|
|
||||||
response = await generate_key_helper_fn(**data_json)
|
response = await generate_key_helper_fn(**data_json)
|
||||||
return GenerateKeyResponse(
|
return GenerateKeyResponse(
|
||||||
key=response["token"], expires=response["expires"], user_id=response["user_id"]
|
key=response["token"], expires=response["expires"], user_id=response["user_id"]
|
||||||
|
|
|
@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
|
||||||
tpm_limit BigInt?
|
tpm_limit BigInt?
|
||||||
rpm_limit BigInt?
|
rpm_limit BigInt?
|
||||||
max_budget Float? @default(0.0)
|
max_budget Float? @default(0.0)
|
||||||
|
budget_duration String?
|
||||||
|
budget_reset_at DateTime?
|
||||||
}
|
}
|
||||||
|
|
||||||
model LiteLLM_Config {
|
model LiteLLM_Config {
|
||||||
|
|
|
@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.proxy.db.base_client import CustomDB
|
from litellm.proxy.db.base_client import CustomDB
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
import smtplib
|
import smtplib, re
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
def print_verbose(print_statement):
|
def print_verbose(print_statement):
|
||||||
|
@ -363,6 +363,8 @@ class PrismaClient:
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
table_name: Optional[Literal["user", "key", "config"]] = None,
|
table_name: Optional[Literal["user", "key", "config"]] = None,
|
||||||
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||||
|
expires: Optional[datetime] = None,
|
||||||
|
reset_at: Optional[datetime] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
print_verbose("PrismaClient: get_data")
|
print_verbose("PrismaClient: get_data")
|
||||||
|
@ -391,6 +393,24 @@ class PrismaClient:
|
||||||
for r in response:
|
for r in response:
|
||||||
if isinstance(r.expires, datetime):
|
if isinstance(r.expires, datetime):
|
||||||
r.expires = r.expires.isoformat()
|
r.expires = r.expires.isoformat()
|
||||||
|
elif (
|
||||||
|
query_type == "find_all"
|
||||||
|
and expires is not None
|
||||||
|
and reset_at is not None
|
||||||
|
):
|
||||||
|
response = await self.db.litellm_verificationtoken.find_many(
|
||||||
|
where={
|
||||||
|
"OR": [
|
||||||
|
{"expires": None},
|
||||||
|
{"expires": {"gt": expires}},
|
||||||
|
],
|
||||||
|
"budget_reset_at": {"lt": reset_at},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if response is not None and len(response) > 0:
|
||||||
|
for r in response:
|
||||||
|
if isinstance(r.expires, datetime):
|
||||||
|
r.expires = r.expires.isoformat()
|
||||||
print_verbose(f"PrismaClient: response={response}")
|
print_verbose(f"PrismaClient: response={response}")
|
||||||
if response is not None:
|
if response is not None:
|
||||||
return response
|
return response
|
||||||
|
@ -517,7 +537,10 @@ class PrismaClient:
|
||||||
self,
|
self,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
data: dict = {},
|
data: dict = {},
|
||||||
|
data_list: Optional[List] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
|
query_type: Literal["update", "update_many"] = "update",
|
||||||
|
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Update existing data
|
Update existing data
|
||||||
|
@ -526,20 +549,21 @@ class PrismaClient:
|
||||||
db_data = self.jsonify_object(data=data)
|
db_data = self.jsonify_object(data=data)
|
||||||
if token is not None:
|
if token is not None:
|
||||||
print_verbose(f"token: {token}")
|
print_verbose(f"token: {token}")
|
||||||
# check if plain text or hash
|
if query_type == "update":
|
||||||
if token.startswith("sk-"):
|
# check if plain text or hash
|
||||||
token = self.hash_token(token=token)
|
if token.startswith("sk-"):
|
||||||
db_data["token"] = token
|
token = self.hash_token(token=token)
|
||||||
response = await self.db.litellm_verificationtoken.update(
|
db_data["token"] = token
|
||||||
where={"token": token}, # type: ignore
|
response = await self.db.litellm_verificationtoken.update(
|
||||||
data={**db_data}, # type: ignore
|
where={"token": token}, # type: ignore
|
||||||
)
|
data={**db_data}, # type: ignore
|
||||||
print_verbose(
|
)
|
||||||
"\033[91m"
|
print_verbose(
|
||||||
+ f"DB Token Table update succeeded {response}"
|
"\033[91m"
|
||||||
+ "\033[0m"
|
+ f"DB Token Table update succeeded {response}"
|
||||||
)
|
+ "\033[0m"
|
||||||
return {"token": token, "data": db_data}
|
)
|
||||||
|
return {"token": token, "data": db_data}
|
||||||
elif user_id is not None:
|
elif user_id is not None:
|
||||||
"""
|
"""
|
||||||
If data['spend'] + data['user'], update the user table with spend info as well
|
If data['spend'] + data['user'], update the user table with spend info as well
|
||||||
|
@ -566,6 +590,33 @@ class PrismaClient:
|
||||||
+ "\033[0m"
|
+ "\033[0m"
|
||||||
)
|
)
|
||||||
return {"user_id": user_id, "data": db_data}
|
return {"user_id": user_id, "data": db_data}
|
||||||
|
elif (
|
||||||
|
table_name is not None
|
||||||
|
and table_name == "key"
|
||||||
|
and query_type == "update_many"
|
||||||
|
and data_list is not None
|
||||||
|
and isinstance(data_list, list)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Batch write update queries
|
||||||
|
"""
|
||||||
|
batcher = self.db.batch_()
|
||||||
|
for idx, t in enumerate(data_list):
|
||||||
|
# check if plain text or hash
|
||||||
|
if t.token.startswith("sk-"): # type: ignore
|
||||||
|
t.token = self.hash_token(token=t.token) # type: ignore
|
||||||
|
try:
|
||||||
|
data_json = self.jsonify_object(data=t.model_dump())
|
||||||
|
except:
|
||||||
|
data_json = self.jsonify_object(data=t.dict())
|
||||||
|
batcher.litellm_verificationtoken.update(
|
||||||
|
where={"token": t.token}, # type: ignore
|
||||||
|
data={**data_json}, # type: ignore
|
||||||
|
)
|
||||||
|
await batcher.commit()
|
||||||
|
print_verbose(
|
||||||
|
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
|
@ -886,3 +937,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
||||||
payload[param] = str(payload[param])
|
payload[param] = str(payload[param])
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _duration_in_seconds(duration: str):
|
||||||
|
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||||
|
if not match:
|
||||||
|
raise ValueError("Invalid duration format")
|
||||||
|
|
||||||
|
value, unit = match.groups()
|
||||||
|
value = int(value)
|
||||||
|
|
||||||
|
if unit == "s":
|
||||||
|
return value
|
||||||
|
elif unit == "m":
|
||||||
|
return value * 60
|
||||||
|
elif unit == "h":
|
||||||
|
return value * 3600
|
||||||
|
elif unit == "d":
|
||||||
|
return value * 86400
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported duration unit")
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_budget(prisma_client: PrismaClient):
|
||||||
|
"""
|
||||||
|
Gets all the non-expired keys for a db, which need budget to be reset
|
||||||
|
|
||||||
|
Resets their budget
|
||||||
|
|
||||||
|
Updates db
|
||||||
|
"""
|
||||||
|
if prisma_client is not None:
|
||||||
|
now = datetime.utcnow()
|
||||||
|
keys_to_reset = await prisma_client.get_data(
|
||||||
|
table_name="key", query_type="find_all", expires=now, reset_at=now
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in keys_to_reset:
|
||||||
|
key.spend = 0.0
|
||||||
|
duration_s = _duration_in_seconds(duration=key.budget_duration)
|
||||||
|
key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)
|
||||||
|
|
||||||
|
if len(keys_to_reset) > 0:
|
||||||
|
await prisma_client.update_data(
|
||||||
|
query_type="update_many", data_list=keys_to_reset, table_name="key"
|
||||||
|
)
|
||||||
|
|
|
@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
|
||||||
tpm_limit BigInt?
|
tpm_limit BigInt?
|
||||||
rpm_limit BigInt?
|
rpm_limit BigInt?
|
||||||
max_budget Float? @default(0.0)
|
max_budget Float? @default(0.0)
|
||||||
|
budget_duration String?
|
||||||
|
budget_reset_at DateTime?
|
||||||
}
|
}
|
||||||
|
|
||||||
model LiteLLM_Config {
|
model LiteLLM_Config {
|
||||||
|
@ -43,8 +45,8 @@ model LiteLLM_Config {
|
||||||
|
|
||||||
model LiteLLM_SpendLogs {
|
model LiteLLM_SpendLogs {
|
||||||
request_id String @unique
|
request_id String @unique
|
||||||
api_key String @default ("")
|
|
||||||
call_type String
|
call_type String
|
||||||
|
api_key String @default ("")
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
startTime DateTime // Assuming start_time is a DateTime field
|
startTime DateTime // Assuming start_time is a DateTime field
|
||||||
endTime DateTime // Assuming end_time is a DateTime field
|
endTime DateTime // Assuming end_time is a DateTime field
|
||||||
|
@ -56,4 +58,4 @@ model LiteLLM_SpendLogs {
|
||||||
usage Json @default("{}")
|
usage Json @default("{}")
|
||||||
metadata Json @default("{}")
|
metadata Json @default("{}")
|
||||||
cache_hit String @default("")
|
cache_hit String @default("")
|
||||||
}
|
}
|
Loading…
Add table
Add a link
Reference in a new issue