Merge pull request #1600 from BerriAI/litellm_global_budget

feat(proxy_server.py): support global budget and resets
This commit is contained in:
Krish Dholakia 2024-01-24 14:55:36 -08:00 committed by GitHub
commit 3e7ed4082a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 120 additions and 18 deletions

View file

@ -425,12 +425,21 @@ class PrismaClient:
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication Error: invalid user key - token does not exist",
)
elif user_id is not None:
response = await self.db.litellm_usertable.find_unique( # type: ignore
where={
"user_id": user_id,
}
)
elif user_id is not None or (
table_name is not None and table_name == "user"
):
if query_type == "find_unique":
response = await self.db.litellm_usertable.find_unique( # type: ignore
where={
"user_id": user_id, # type: ignore
}
)
elif query_type == "find_all" and reset_at is not None:
response = await self.db.litellm_usertable.find_many(
where={ # type:ignore
"budget_reset_at": {"lt": reset_at},
}
)
return response
elif table_name == "user" and query_type == "find_all":
response = await self.db.litellm_usertable.find_many( # type: ignore
@ -597,10 +606,16 @@ class PrismaClient:
+ "\033[0m"
)
return {"token": token, "data": db_data}
elif user_id is not None:
elif (
user_id is not None
or (table_name is not None and table_name == "user")
and query_type == "update"
):
"""
If data['spend'] + data['user'], update the user table with spend info as well
"""
if user_id is None:
user_id = db_data["user_id"]
update_user_row = await self.db.litellm_usertable.update(
where={"user_id": user_id}, # type: ignore
data={**db_data}, # type: ignore
@ -650,6 +665,30 @@ class PrismaClient:
print_verbose(
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
)
elif (
table_name is not None
and table_name == "user"
and query_type == "update_many"
and data_list is not None
and isinstance(data_list, list)
):
"""
Batch write update queries
"""
batcher = self.db.batch_()
for idx, user in enumerate(data_list):
try:
data_json = self.jsonify_object(data=user.model_dump())
except:
data_json = self.jsonify_object(data=user.dict())
batcher.litellm_usertable.update(
where={"user_id": user.user_id}, # type: ignore
data={**data_json}, # type: ignore
)
await batcher.commit()
print_verbose(
"\033[91m" + f"DB User Table update succeeded" + "\033[0m"
)
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
@ -1007,17 +1046,36 @@ async def reset_budget(prisma_client: PrismaClient):
Updates db
"""
if prisma_client is not None:
### RESET KEY BUDGET ###
now = datetime.utcnow()
keys_to_reset = await prisma_client.get_data(
table_name="key", query_type="find_all", expires=now, reset_at=now
)
for key in keys_to_reset:
key.spend = 0.0
duration_s = _duration_in_seconds(duration=key.budget_duration)
key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)
if keys_to_reset is not None and len(keys_to_reset) > 0:
for key in keys_to_reset:
key.spend = 0.0
duration_s = _duration_in_seconds(duration=key.budget_duration)
key.budget_reset_at = now + timedelta(seconds=duration_s)
if len(keys_to_reset) > 0:
await prisma_client.update_data(
query_type="update_many", data_list=keys_to_reset, table_name="key"
)
### RESET USER BUDGET ###
now = datetime.utcnow()
users_to_reset = await prisma_client.get_data(
table_name="user", query_type="find_all", reset_at=now
)
verbose_proxy_logger.debug(f"users_to_reset from get_data: {users_to_reset}")
if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset:
user.spend = 0.0
duration_s = _duration_in_seconds(duration=user.budget_duration)
user.budget_reset_at = now + timedelta(seconds=duration_s)
await prisma_client.update_data(
query_type="update_many", data_list=users_to_reset, table_name="user"
)