mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
feat(proxy/utils.py): enable background process to reset key budgets
This commit is contained in:
parent
43ea1e589f
commit
e6f05e858b
5 changed files with 138 additions and 18 deletions
|
@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
from litellm.proxy.db.base_client import CustomDB
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from fastapi import HTTPException, status
|
||||
import smtplib
|
||||
import smtplib, re
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
|
@ -363,6 +363,8 @@ class PrismaClient:
|
|||
user_id: Optional[str] = None,
|
||||
table_name: Optional[Literal["user", "key", "config"]] = None,
|
||||
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||
expires: Optional[datetime] = None,
|
||||
reset_at: Optional[datetime] = None,
|
||||
):
|
||||
try:
|
||||
print_verbose("PrismaClient: get_data")
|
||||
|
@ -391,6 +393,24 @@ class PrismaClient:
|
|||
for r in response:
|
||||
if isinstance(r.expires, datetime):
|
||||
r.expires = r.expires.isoformat()
|
||||
elif (
|
||||
query_type == "find_all"
|
||||
and expires is not None
|
||||
and reset_at is not None
|
||||
):
|
||||
response = await self.db.litellm_verificationtoken.find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"expires": None},
|
||||
{"expires": {"gt": expires}},
|
||||
],
|
||||
"budget_reset_at": {"lt": reset_at},
|
||||
}
|
||||
)
|
||||
if response is not None and len(response) > 0:
|
||||
for r in response:
|
||||
if isinstance(r.expires, datetime):
|
||||
r.expires = r.expires.isoformat()
|
||||
print_verbose(f"PrismaClient: response={response}")
|
||||
if response is not None:
|
||||
return response
|
||||
|
@ -517,7 +537,10 @@ class PrismaClient:
|
|||
self,
|
||||
token: Optional[str] = None,
|
||||
data: dict = {},
|
||||
data_list: Optional[List] = None,
|
||||
user_id: Optional[str] = None,
|
||||
query_type: Literal["update", "update_many"] = "update",
|
||||
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
||||
):
|
||||
"""
|
||||
Update existing data
|
||||
|
@ -526,20 +549,21 @@ class PrismaClient:
|
|||
db_data = self.jsonify_object(data=data)
|
||||
if token is not None:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={"token": token}, # type: ignore
|
||||
data={**db_data}, # type: ignore
|
||||
)
|
||||
print_verbose(
|
||||
"\033[91m"
|
||||
+ f"DB Token Table update succeeded {response}"
|
||||
+ "\033[0m"
|
||||
)
|
||||
return {"token": token, "data": db_data}
|
||||
if query_type == "update":
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={"token": token}, # type: ignore
|
||||
data={**db_data}, # type: ignore
|
||||
)
|
||||
print_verbose(
|
||||
"\033[91m"
|
||||
+ f"DB Token Table update succeeded {response}"
|
||||
+ "\033[0m"
|
||||
)
|
||||
return {"token": token, "data": db_data}
|
||||
elif user_id is not None:
|
||||
"""
|
||||
If data['spend'] + data['user'], update the user table with spend info as well
|
||||
|
@ -566,6 +590,33 @@ class PrismaClient:
|
|||
+ "\033[0m"
|
||||
)
|
||||
return {"user_id": user_id, "data": db_data}
|
||||
elif (
|
||||
table_name is not None
|
||||
and table_name == "key"
|
||||
and query_type == "update_many"
|
||||
and data_list is not None
|
||||
and isinstance(data_list, list)
|
||||
):
|
||||
"""
|
||||
Batch write update queries
|
||||
"""
|
||||
batcher = self.db.batch_()
|
||||
for idx, t in enumerate(data_list):
|
||||
# check if plain text or hash
|
||||
if t.token.startswith("sk-"): # type: ignore
|
||||
t.token = self.hash_token(token=t.token) # type: ignore
|
||||
try:
|
||||
data_json = self.jsonify_object(data=t.model_dump())
|
||||
except:
|
||||
data_json = self.jsonify_object(data=t.dict())
|
||||
batcher.litellm_verificationtoken.update(
|
||||
where={"token": t.token}, # type: ignore
|
||||
data={**data_json}, # type: ignore
|
||||
)
|
||||
await batcher.commit()
|
||||
print_verbose(
|
||||
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
|
||||
)
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
|
@ -886,3 +937,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
|||
payload[param] = str(payload[param])
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
if not match:
|
||||
raise ValueError("Invalid duration format")
|
||||
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
|
||||
if unit == "s":
|
||||
return value
|
||||
elif unit == "m":
|
||||
return value * 60
|
||||
elif unit == "h":
|
||||
return value * 3600
|
||||
elif unit == "d":
|
||||
return value * 86400
|
||||
else:
|
||||
raise ValueError("Unsupported duration unit")
|
||||
|
||||
|
||||
async def reset_budget(prisma_client: PrismaClient):
|
||||
"""
|
||||
Gets all the non-expired keys for a db, which need budget to be reset
|
||||
|
||||
Resets their budget
|
||||
|
||||
Updates db
|
||||
"""
|
||||
if prisma_client is not None:
|
||||
now = datetime.utcnow()
|
||||
keys_to_reset = await prisma_client.get_data(
|
||||
table_name="key", query_type="find_all", expires=now, reset_at=now
|
||||
)
|
||||
|
||||
for key in keys_to_reset:
|
||||
key.spend = 0.0
|
||||
duration_s = _duration_in_seconds(duration=key.budget_duration)
|
||||
key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)
|
||||
|
||||
if len(keys_to_reset) > 0:
|
||||
await prisma_client.update_data(
|
||||
query_type="update_many", data_list=keys_to_reset, table_name="key"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue