fix(proxy/utils.py): move to batch writing db updates

This commit is contained in:
Krrish Dholakia 2024-03-16 22:32:00 -07:00
parent 710efab0de
commit 077b9c6234
4 changed files with 269 additions and 95 deletions

View file

@ -472,6 +472,11 @@ def on_backoff(details):
class PrismaClient:
user_list_transactons: List = []
key_list_transactons: List = []
team_list_transactons: List = []
spend_log_transactons: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
@ -1841,6 +1846,65 @@ async def reset_budget(prisma_client: PrismaClient):
)
async def update_spend(
    prisma_client: PrismaClient,
):
    """
    Batch write queued spend updates to the db.

    Called by `monitor_spend_list` when the queue grows large; per the
    original note it is also meant to be triggered every minute.

    Reads and drains:
        prisma_client.user_list_transactons: list of
            (user_id, response_cost) tuples. Entries whose user_id is
            "litellm-proxy-budget" are skipped (but still removed from
            the queue).

    Only the entries that were queued when the call started are flushed, in
    batches of at most 5000 per db transaction. Each batch is removed from the
    queue as soon as its transaction block exits without error, so:
      - a retry never re-applies a batch that was already written, and
      - entries appended to the queue while a write is awaiting are kept
        for the next flush instead of being discarded.

    Raises:
        httpx.ReadTimeout: if a write still times out after the last retry.
        Any other exception from the db client propagates immediately.
    """
    n_retry_times = 3
    max_batch_size = 5000

    ### UPDATE USER TABLE ###
    # Number of entries this call is responsible for; later arrivals wait
    # for the next flush so this loop is bounded.
    pending_count = len(prisma_client.user_list_transactons)
    if pending_count > 0:
        for i in range(n_retry_times + 1):
            try:
                while pending_count > 0:
                    queue = prisma_client.user_list_transactons
                    batch_transactions = queue[: min(max_batch_size, pending_count)]
                    if not batch_transactions:
                        # Queue was emptied/replaced elsewhere; nothing left to do.
                        pending_count = 0
                        break
                    async with prisma_client.db.tx(timeout=60000) as transaction:
                        async with transaction.batch_() as batcher:
                            for user_id, response_cost in batch_transactions:
                                if user_id != "litellm-proxy-budget":
                                    batcher.litellm_usertable.update(
                                        where={"user_id": user_id},
                                        data={"spend": {"increment": response_cost}},
                                    )
                    # The transaction block exited cleanly: drop exactly the
                    # entries just written. Mutating in place (rather than
                    # rebinding to []) keeps anything appended meanwhile.
                    del queue[: len(batch_transactions)]
                    pending_count -= len(batch_transactions)
                break  # everything flushed - do not run further retry iterations
            except httpx.ReadTimeout:
                if i >= n_retry_times:  # If we've reached the maximum number of retries
                    raise  # Re-raise the last exception
                # The timed-out batch is still at the head of the queue, so the
                # next iteration resends it after an exponential backoff.
                await asyncio.sleep(2**i)

    ### UPDATE KEY TABLE ###
    ### UPDATE TEAM TABLE ###
    ### UPDATE SPEND LOGS TABLE ###
async def monitor_spend_list(prisma_client: PrismaClient):
    """
    Flush queued spend updates early when the backlog gets large.

    Only the user spend queue (`prisma_client.user_list_transactons`) is
    inspected. When it holds more than 10000 entries, `update_spend` is
    awaited to write the queue to the db; otherwise nothing happens.
    """
    queued_user_updates = len(prisma_client.user_list_transactons)
    if queued_user_updates <= 10000:
        # Backlog is still small enough; nothing to write yet.
        return
    await update_spend(prisma_client=prisma_client)
async def _read_request_body(request):
"""
Asynchronous function to read the request body and parse it as JSON or literal data.