feat - update team+user_id spend

This commit is contained in:
Ishaan Jaff 2024-05-22 17:49:54 -07:00
parent f548334e8b
commit c2d25b9a14
2 changed files with 64 additions and 0 deletions

View file

@ -551,6 +551,7 @@ class PrismaClient:
end_user_list_transactons: dict = {}
key_list_transactons: dict = {}
team_list_transactons: dict = {}
team_member_list_transactons: dict = {} # key is ["team_id" + "user_id"]
org_list_transactons: dict = {}
spend_log_transactions: List = []
@ -2257,6 +2258,56 @@ async def update_spend(
)
raise e
### UPDATE TEAM Membership TABLE with spend ###
if len(prisma_client.team_member_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
key,
response_cost,
) in prisma_client.team_member_list_transactons.items():
# key is "team_id::<value>::user_id::<value>"
team_id = key.split("::")[1]
user_id = key.split("::")[3]
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
where={"team_id": team_id, "user_id": user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.team_member_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update team spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):