feat(proxy/utils.py): enable background process to reset key budgets

This commit is contained in:
Krrish Dholakia 2024-01-23 12:33:13 -08:00
parent 01a2514b98
commit b1a105e309
5 changed files with 138 additions and 18 deletions

View file

@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status
import smtplib
import smtplib, re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime
from datetime import datetime, timedelta
def print_verbose(print_statement):
@ -363,6 +363,8 @@ class PrismaClient:
user_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique",
expires: Optional[datetime] = None,
reset_at: Optional[datetime] = None,
):
try:
print_verbose("PrismaClient: get_data")
@ -391,6 +393,24 @@ class PrismaClient:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
elif (
query_type == "find_all"
and expires is not None
and reset_at is not None
):
response = await self.db.litellm_verificationtoken.find_many(
where={
"OR": [
{"expires": None},
{"expires": {"gt": expires}},
],
"budget_reset_at": {"lt": reset_at},
}
)
if response is not None and len(response) > 0:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
print_verbose(f"PrismaClient: response={response}")
if response is not None:
return response
@ -517,7 +537,10 @@ class PrismaClient:
self,
token: Optional[str] = None,
data: dict = {},
data_list: Optional[List] = None,
user_id: Optional[str] = None,
query_type: Literal["update", "update_many"] = "update",
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
):
"""
Update existing data
@ -526,20 +549,21 @@ class PrismaClient:
db_data = self.jsonify_object(data=data)
if token is not None:
print_verbose(f"token: {token}")
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose(
"\033[91m"
+ f"DB Token Table update succeeded {response}"
+ "\033[0m"
)
return {"token": token, "data": db_data}
if query_type == "update":
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose(
"\033[91m"
+ f"DB Token Table update succeeded {response}"
+ "\033[0m"
)
return {"token": token, "data": db_data}
elif user_id is not None:
"""
If data['spend'] + data['user'], update the user table with spend info as well
@ -566,6 +590,33 @@ class PrismaClient:
+ "\033[0m"
)
return {"user_id": user_id, "data": db_data}
elif (
table_name is not None
and table_name == "key"
and query_type == "update_many"
and data_list is not None
and isinstance(data_list, list)
):
"""
Batch write update queries
"""
batcher = self.db.batch_()
for idx, t in enumerate(data_list):
# check if plain text or hash
if t.token.startswith("sk-"): # type: ignore
t.token = self.hash_token(token=t.token) # type: ignore
try:
data_json = self.jsonify_object(data=t.model_dump())
except:
data_json = self.jsonify_object(data=t.dict())
batcher.litellm_verificationtoken.update(
where={"token": t.token}, # type: ignore
data={**data_json}, # type: ignore
)
await batcher.commit()
print_verbose(
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
)
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
@ -886,3 +937,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
payload[param] = str(payload[param])
return payload
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
else:
raise ValueError("Unsupported duration unit")
async def reset_budget(prisma_client: PrismaClient):
"""
Gets all the non-expired keys for a db, which need budget to be reset
Resets their budget
Updates db
"""
if prisma_client is not None:
now = datetime.utcnow()
keys_to_reset = await prisma_client.get_data(
table_name="key", query_type="find_all", expires=now, reset_at=now
)
for key in keys_to_reset:
key.spend = 0.0
duration_s = _duration_in_seconds(duration=key.budget_duration)
key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)
if len(keys_to_reset) > 0:
await prisma_client.update_data(
query_type="update_many", data_list=keys_to_reset, table_name="key"
)