diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index 34e8eae173..cc45b6b96a 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -7,16 +7,28 @@ Module responsible for import asyncio import os +import time import traceback -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Optional, Union import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache -from litellm.proxy._types import Litellm_EntityType, LiteLLM_UserTable, SpendLogsPayload +from litellm.proxy._types import ( + DB_CONNECTION_ERROR_TYPES, + Litellm_EntityType, + LiteLLM_UserTable, + SpendLogsPayload, +) from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload -from litellm.proxy.utils import PrismaClient, ProxyUpdateSpend, hash_token +from litellm.proxy.utils import ( + PrismaClient, + ProxyLogging, + ProxyUpdateSpend, + _raise_failed_update_spend_exception, + hash_token, +) class DBSpendUpdateWriter: @@ -346,3 +358,235 @@ class DBSpendUpdateWriter: payload.copy() ) return prisma_client + + @staticmethod + async def commit_update_transactions_to_db( # noqa: PLR0915 + prisma_client: PrismaClient, + n_retry_times: int, + proxy_logging_obj: ProxyLogging, + ): + """ + Handles commiting update spend transactions to db + + UPDATES can lead to deadlocks, hence we handle them separately + + Args: + prisma_client: PrismaClient object + n_retry_times: int, number of retry times + proxy_logging_obj: ProxyLogging object + """ + ### UPDATE USER TABLE ### + if len(prisma_client.user_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + user_id, + response_cost, + ) in prisma_client.user_list_transactons.items(): + batcher.litellm_usertable.update_many( + where={"user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.user_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if ( + i >= n_retry_times + ): # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, + start_time=start_time, + proxy_logging_obj=proxy_logging_obj, + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE END-USER TABLE ### + verbose_proxy_logger.debug( + "End-User Spend transactions: {}".format( + len(prisma_client.end_user_list_transactons.keys()) + ) + ) + if len(prisma_client.end_user_list_transactons.keys()) > 0: + await ProxyUpdateSpend.update_end_user_spend( + n_retry_times=n_retry_times, + prisma_client=prisma_client, + proxy_logging_obj=proxy_logging_obj, + ) + ### UPDATE KEY TABLE ### + verbose_proxy_logger.debug( + "KEY Spend transactions: {}".format( + len(prisma_client.key_list_transactons.keys()) + ) + ) + if len(prisma_client.key_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + token, + response_cost, + ) in prisma_client.key_list_transactons.items(): + batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists + where={"token": token}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.key_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if ( + i >= n_retry_times + ): # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, + start_time=start_time, + proxy_logging_obj=proxy_logging_obj, + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE TEAM TABLE ### + verbose_proxy_logger.debug( + "Team Spend transactions: {}".format( + len(prisma_client.team_list_transactons.keys()) + ) + ) + if len(prisma_client.team_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + team_id, + response_cost, + ) in prisma_client.team_list_transactons.items(): + verbose_proxy_logger.debug( + "Updating spend for team id={} by {}".format( + team_id, response_cost + ) + ) + batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if ( + i >= n_retry_times + ): # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, + start_time=start_time, + proxy_logging_obj=proxy_logging_obj, + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE TEAM Membership TABLE with spend ### + if len(prisma_client.team_member_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + key, + response_cost, + ) in prisma_client.team_member_list_transactons.items(): + # key is "team_id::::user_id::" + team_id = key.split("::")[1] + user_id = key.split("::")[3] + + batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id, "user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_member_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if ( + i >= n_retry_times + ): # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, + start_time=start_time, + proxy_logging_obj=proxy_logging_obj, + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE ORG TABLE ### + if len(prisma_client.org_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + org_id, + response_cost, + ) in prisma_client.org_list_transactons.items(): + batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"organization_id": org_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.org_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if ( + i >= n_retry_times + ): # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, + start_time=start_time, + proxy_logging_obj=proxy_logging_obj, + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7f1ac814a8..013ecf97c4 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -62,6 +62,7 @@ from litellm.proxy.db.create_views import ( create_missing_views, should_create_missing_views, ) +from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter from litellm.proxy.db.log_db_metrics import log_db_metrics from litellm.proxy.db.prisma_client import PrismaWrapper from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck @@ -2674,202 +2675,11 @@ async def update_spend( # noqa: PLR0915 spend_logs: list, """ n_retry_times = 3 - i = None - ### UPDATE USER TABLE ### - if len(prisma_client.user_list_transactons.keys()) > 0: - for i in range(n_retry_times + 1): - start_time = time.time() - try: - async with prisma_client.db.tx( - timeout=timedelta(seconds=60) - ) as transaction: - async with transaction.batch_() as batcher: - for ( - user_id, - response_cost, - ) in prisma_client.user_list_transactons.items(): - batcher.litellm_usertable.update_many( - where={"user_id": user_id}, - data={"spend": {"increment": response_cost}}, - ) - prisma_client.user_list_transactons = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. - break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff - except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - - ### UPDATE END-USER TABLE ### - verbose_proxy_logger.debug( - "End-User Spend transactions: {}".format( - len(prisma_client.end_user_list_transactons.keys()) - ) + await DBSpendUpdateWriter.commit_update_transactions_to_db( + prisma_client=prisma_client, + n_retry_times=n_retry_times, + proxy_logging_obj=proxy_logging_obj, ) - if len(prisma_client.end_user_list_transactons.keys()) > 0: - await ProxyUpdateSpend.update_end_user_spend( - n_retry_times=n_retry_times, - prisma_client=prisma_client, - proxy_logging_obj=proxy_logging_obj, - ) - ### UPDATE KEY TABLE ### - verbose_proxy_logger.debug( - "KEY Spend transactions: {}".format( - len(prisma_client.key_list_transactons.keys()) - ) - ) - if len(prisma_client.key_list_transactons.keys()) > 0: - for i in range(n_retry_times + 1): - start_time = time.time() - try: - async with prisma_client.db.tx( - timeout=timedelta(seconds=60) - ) as transaction: - async with transaction.batch_() as batcher: - for ( - token, - response_cost, - ) in prisma_client.key_list_transactons.items(): - batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists - where={"token": token}, - data={"spend": {"increment": response_cost}}, - ) - prisma_client.key_list_transactons = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. - break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff - except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - - ### UPDATE TEAM TABLE ### - verbose_proxy_logger.debug( - "Team Spend transactions: {}".format( - len(prisma_client.team_list_transactons.keys()) - ) - ) - if len(prisma_client.team_list_transactons.keys()) > 0: - for i in range(n_retry_times + 1): - start_time = time.time() - try: - async with prisma_client.db.tx( - timeout=timedelta(seconds=60) - ) as transaction: - async with transaction.batch_() as batcher: - for ( - team_id, - response_cost, - ) in prisma_client.team_list_transactons.items(): - verbose_proxy_logger.debug( - "Updating spend for team id={} by {}".format( - team_id, response_cost - ) - ) - batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists - where={"team_id": team_id}, - data={"spend": {"increment": response_cost}}, - ) - prisma_client.team_list_transactons = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. - break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff - except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - - ### UPDATE TEAM Membership TABLE with spend ### - if len(prisma_client.team_member_list_transactons.keys()) > 0: - for i in range(n_retry_times + 1): - start_time = time.time() - try: - async with prisma_client.db.tx( - timeout=timedelta(seconds=60) - ) as transaction: - async with transaction.batch_() as batcher: - for ( - key, - response_cost, - ) in prisma_client.team_member_list_transactons.items(): - # key is "team_id::::user_id::" - team_id = key.split("::")[1] - user_id = key.split("::")[3] - - batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists - where={"team_id": team_id, "user_id": user_id}, - data={"spend": {"increment": response_cost}}, - ) - prisma_client.team_member_list_transactons = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. - break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff - except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - - ### UPDATE ORG TABLE ### - if len(prisma_client.org_list_transactons.keys()) > 0: - for i in range(n_retry_times + 1): - start_time = time.time() - try: - async with prisma_client.db.tx( - timeout=timedelta(seconds=60) - ) as transaction: - async with transaction.batch_() as batcher: - for ( - org_id, - response_cost, - ) in prisma_client.org_list_transactons.items(): - batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists - where={"organization_id": org_id}, - data={"spend": {"increment": response_cost}}, - ) - prisma_client.org_list_transactons = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. - break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff - except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) ### UPDATE SPEND LOGS ### verbose_proxy_logger.debug(