diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1bbb050f43..1db7150f0e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1702,6 +1702,19 @@ async def update_database( response_cost + prisma_client.team_list_transactons.get(team_id, 0) ) + + try: + # Track spend of the team member within this team + # key is "team_id::::user_id::" + team_member_key = f"team_id::{team_id}::user_id::{user_id}" + prisma_client.team_member_list_transactons[team_member_key] = ( + response_cost + + prisma_client.team_member_list_transactons.get( + team_member_key, 0 + ) + ) + except: + pass except Exception as e: verbose_proxy_logger.info( f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index ffff87c51f..0742c21094 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -551,6 +551,7 @@ class PrismaClient: end_user_list_transactons: dict = {} key_list_transactons: dict = {} team_list_transactons: dict = {} + team_member_list_transactons: dict = {} # key is ["team_id" + "user_id"] org_list_transactons: dict = {} spend_log_transactions: List = [] @@ -2257,6 +2258,56 @@ async def update_spend( ) raise e + ### UPDATE TEAM Membership TABLE with spend ### + if len(prisma_client.team_member_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + key, + response_cost, + ) in prisma_client.team_member_list_transactons.items(): + # key is "team_id::::user_id::" + team_id = key.split("::")[1] + user_id = key.split("::")[3] + + batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id, "user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_member_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update team spend: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, + duration=_duration, + call_type="update_spend", + traceback_str=error_traceback, + ) + ) + raise e + ### UPDATE ORG TABLE ### if len(prisma_client.org_list_transactons.keys()) > 0: for i in range(n_retry_times + 1):