fix(proxy/utils.py): batch writing updates to db

This commit is contained in:
Krrish Dholakia 2024-03-18 16:47:02 -07:00
parent 1618751824
commit 1b10123f07
3 changed files with 141 additions and 217 deletions

View file

@ -7,6 +7,10 @@ from litellm.proxy._types import (
LiteLLM_VerificationToken,
LiteLLM_VerificationTokenView,
LiteLLM_SpendLogs,
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_TeamTable,
Member,
)
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
@ -472,9 +476,10 @@ def on_backoff(details):
class PrismaClient:
user_list_transactons: List = []
key_list_transactons: List = []
team_list_transactons: List = []
user_list_transactons: dict = {}
end_user_list_transactons: dict = {}
key_list_transactons: dict = {}
team_list_transactons: dict = {}
spend_log_transactons: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
@ -1855,34 +1860,62 @@ async def update_spend(
Triggered every minute.
Requires:
user_id_list: list,
user_id_list: dict,
keys_list: list,
team_list: list,
spend_logs: list,
"""
verbose_proxy_logger.debug(
f"ENTERS UPDATE SPEND - len(prisma_client.user_list_transactons.keys()): {len(prisma_client.user_list_transactons.keys())}"
)
n_retry_times = 3
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons) > 0:
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
remaining_transactions = list(prisma_client.user_list_transactons)
while remaining_transactions:
batch_size = min(5000, len(remaining_transactions))
batch_transactions = remaining_transactions[:batch_size]
async with prisma_client.db.tx(timeout=60000) as transaction:
async with transaction.batch_() as batcher:
for user_id_tuple in batch_transactions:
user_id, response_cost = user_id_tuple
if user_id != "litellm-proxy-budget":
batcher.litellm_usertable.update(
where={"user_id": user_id},
data={"spend": {"increment": response_cost}},
)
remaining_transactions = remaining_transactions[batch_size:]
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
user_id,
response_cost,
) in prisma_client.user_list_transactons.items():
batcher.litellm_usertable.update_many(
where={"user_id": user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.user_list_transactons = (
[]
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
raise e
### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
end_user_id,
response_cost,
) in prisma_client.end_user_list_transactons.items():
max_user_budget = None
if litellm.max_user_budget is not None:
max_user_budget = litellm.max_user_budget
new_user_obj = LiteLLM_EndUserTable(
user_id=end_user_id, spend=response_cost, blocked=False
)
batcher.litellm_endusertable.update_many(
where={"user_id": end_user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.end_user_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
@ -1893,7 +1926,55 @@ async def update_spend(
raise e
### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
token,
response_cost,
) in prisma_client.key_list_transactons.items():
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
where={"token": token},
data={"spend": {"increment": response_cost}},
)
prisma_client.key_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
raise e
### UPDATE TEAM TABLE ###
if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
team_id,
response_cost,
) in prisma_client.team_list_transactons.items():
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
where={"team_id": team_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.team_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
raise e
### UPDATE SPEND LOGS TABLE ###