diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index 7184262ee5..eee598f79a 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -47,14 +47,9 @@ class DBSpendUpdateWriter: self, redis_cache: Optional[RedisCache] = None, ): - from litellm.proxy.proxy_server import prisma_client - self.redis_cache = redis_cache self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache) - self.pod_leader_manager = PodLockManager( - cronjob_id=DB_SPEND_UPDATE_JOB_NAME, - prisma_client=prisma_client, - ) + self.pod_leader_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME) @staticmethod async def update_database( @@ -411,17 +406,23 @@ class DBSpendUpdateWriter: # Only commit from redis to db if this pod is the leader if await self.pod_leader_manager.acquire_lock(): - db_spend_update_transactions = ( - await self.redis_update_buffer.get_all_update_transactions_from_redis() - ) - if db_spend_update_transactions is not None: - await DBSpendUpdateWriter._commit_spend_updates_to_db( - prisma_client=prisma_client, - n_retry_times=n_retry_times, - proxy_logging_obj=proxy_logging_obj, - db_spend_update_transactions=db_spend_update_transactions, + verbose_proxy_logger.debug("acquired lock for spend updates") + + try: + db_spend_update_transactions = ( + await self.redis_update_buffer.get_all_update_transactions_from_redis() ) - await self.pod_leader_manager.release_lock() + if db_spend_update_transactions is not None: + await DBSpendUpdateWriter._commit_spend_updates_to_db( + prisma_client=prisma_client, + n_retry_times=n_retry_times, + proxy_logging_obj=proxy_logging_obj, + db_spend_update_transactions=db_spend_update_transactions, + ) + except Exception as e: + verbose_proxy_logger.error(f"Error committing spend updates: {e}") + finally: + await self.pod_leader_manager.release_lock() else: db_spend_update_transactions = DBSpendUpdateTransactions( user_list_transactions=prisma_client.user_list_transactions, @@ -456,7 +457,10 @@ class DBSpendUpdateWriter: ### UPDATE USER TABLE ### user_list_transactions = db_spend_update_transactions["user_list_transactions"] - if len(user_list_transactions.keys()) > 0: + if ( + user_list_transactions is not None + and len(user_list_transactions.keys()) > 0 + ): for i in range(n_retry_times + 1): start_time = time.time() try: @@ -501,7 +505,10 @@ class DBSpendUpdateWriter: end_user_list_transactions = db_spend_update_transactions[ "end_user_list_transactions" ] - if len(end_user_list_transactions.keys()) > 0: + if ( + end_user_list_transactions is not None + and len(end_user_list_transactions.keys()) > 0 + ): await ProxyUpdateSpend.update_end_user_spend( n_retry_times=n_retry_times, prisma_client=prisma_client, @@ -510,9 +517,9 @@ class DBSpendUpdateWriter: ### UPDATE KEY TABLE ### key_list_transactions = db_spend_update_transactions["key_list_transactions"] verbose_proxy_logger.debug( - "KEY Spend transactions: {}".format(len(key_list_transactions.keys())) + "KEY Spend transactions: {}".format(key_list_transactions) ) - if len(key_list_transactions.keys()) > 0: + if key_list_transactions is not None and len(key_list_transactions.keys()) > 0: for i in range(n_retry_times + 1): start_time = time.time() try: @@ -555,7 +562,10 @@ class DBSpendUpdateWriter: ) ) team_list_transactions = db_spend_update_transactions["team_list_transactions"] - if len(team_list_transactions.keys()) > 0: + if ( + team_list_transactions is not None + and len(team_list_transactions.keys()) > 0 + ): for i in range(n_retry_times + 1): start_time = time.time() try: @@ -600,7 +610,10 @@ class DBSpendUpdateWriter: team_member_list_transactions = db_spend_update_transactions[ "team_member_list_transactions" ] - if len(team_member_list_transactions.keys()) > 0: + if ( + team_member_list_transactions is not None + and len(team_member_list_transactions.keys()) > 0 + ): for i in range(n_retry_times + 1): start_time = time.time() try: @@ -642,7 +655,7 @@ class DBSpendUpdateWriter: ### UPDATE ORG TABLE ### org_list_transactions = db_spend_update_transactions["org_list_transactions"] - if len(org_list_transactions.keys()) > 0: + if org_list_transactions is not None and len(org_list_transactions.keys()) > 0: for i in range(n_retry_times + 1): start_time = time.time() try: diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py index 22afc56483..94cd1e47c7 100644 --- a/litellm/proxy/db/redis_update_buffer.py +++ b/litellm/proxy/db/redis_update_buffer.py @@ -17,12 +17,12 @@ else: class DBSpendUpdateTransactions(TypedDict): - user_list_transactions: Dict[str, float] - end_user_list_transactions: Dict[str, float] - key_list_transactions: Dict[str, float] - team_list_transactions: Dict[str, float] - team_member_list_transactions: Dict[str, float] - org_list_transactions: Dict[str, float] + user_list_transactions: Optional[Dict[str, float]] + end_user_list_transactions: Optional[Dict[str, float]] + key_list_transactions: Optional[Dict[str, float]] + team_list_transactions: Optional[Dict[str, float]] + team_member_list_transactions: Optional[Dict[str, float]] + org_list_transactions: Optional[Dict[str, float]] class RedisUpdateBuffer: