From 7a63cbe8d0dea19e831338a52764a1286d1cdc6f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 28 Mar 2025 17:24:54 -0700 Subject: [PATCH] redis update buffer queue --- litellm/proxy/db/db_spend_update_writer.py | 2 +- litellm/proxy/db/redis_update_buffer.py | 240 +++++++++++++-------- 2 files changed, 152 insertions(+), 90 deletions(-) diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index fea40e58e7..55b95d16bf 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -436,7 +436,7 @@ class DBSpendUpdateWriter: try: db_spend_update_transactions = ( - await self.redis_update_buffer.get_all_update_transactions_from_redis() + await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer() ) if db_spend_update_transactions is not None: await DBSpendUpdateWriter._commit_spend_updates_to_db( diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py index 6f334e90ed..ab43dc575d 100644 --- a/litellm/proxy/db/redis_update_buffer.py +++ b/litellm/proxy/db/redis_update_buffer.py @@ -4,10 +4,12 @@ Handles buffering database `UPDATE` transactions in Redis before committing them This is to prevent deadlocks and improve reliability """ -from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast +import json +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from litellm._logging import verbose_proxy_logger from litellm.caching import RedisCache +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import DBSpendUpdateTransactions from litellm.secret_managers.main import str_to_bool @@ -16,6 +18,9 @@ if TYPE_CHECKING: else: PrismaClient = Any +REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer" +MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100 + class RedisUpdateBuffer: """ @@ -61,14 +66,21 @@ class RedisUpdateBuffer: - value is the spend amount ``` - { - "0929880201": 10, - "0929880202": 20, - "0929880203": 30, - } + Redis List: + key_list_transactions: + [ + "0929880201": 1.2, + "0929880202": 0.01, + "0929880203": 0.001, + ] ``` """ - IN_MEMORY_UPDATE_TRANSACTIONS: DBSpendUpdateTransactions = ( + if self.redis_cache is None: + verbose_proxy_logger.debug( + "redis_cache is None, skipping store_in_memory_spend_updates_in_redis" + ) + return + db_spend_update_transactions: DBSpendUpdateTransactions = ( DBSpendUpdateTransactions( user_list_transactions=prisma_client.user_list_transactions, end_user_list_transactions=prisma_client.end_user_list_transactions, @@ -78,30 +90,47 @@ class RedisUpdateBuffer: org_list_transactions=prisma_client.org_list_transactions, ) ) - for key, _transactions in IN_MEMORY_UPDATE_TRANSACTIONS.items(): - await self.increment_all_transaction_objects_in_redis( - key=key, - transactions=cast(Dict, _transactions), - ) - async def increment_all_transaction_objects_in_redis( - self, - key: str, - transactions: Dict, + # only store in redis if there are any updates to commit + if ( + self._number_of_transactions_to_store_in_redis(db_spend_update_transactions) + == 0 + ): + return + + list_of_transactions = [safe_dumps(db_spend_update_transactions)] + await self.redis_cache.async_rpush( + key=REDIS_UPDATE_BUFFER_KEY, + values=list_of_transactions, + ) + self._clear_all_in_memory_spend_updates(prisma_client) + + @staticmethod + def _number_of_transactions_to_store_in_redis( + db_spend_update_transactions: DBSpendUpdateTransactions, + ) -> int: + """ + Gets the number of transactions to store in Redis + """ + num_transactions = 0 + for v in db_spend_update_transactions.values(): + if isinstance(v, dict): + num_transactions += len(v) + return num_transactions + + @staticmethod + def _clear_all_in_memory_spend_updates( + prisma_client: PrismaClient, ): """ - Increments all transaction objects in Redis + Clears all in-memory spend updates """ - if self.redis_cache is None: - verbose_proxy_logger.debug( - "redis_cache is None, skipping increment_all_transaction_objects_in_redis" - ) - return - for transaction_id, transaction_amount in transactions.items(): - await self.redis_cache.async_increment( - key=f"{key}:{transaction_id}", - value=transaction_amount, - ) + prisma_client.user_list_transactions = {} + prisma_client.end_user_list_transactions = {} + prisma_client.key_list_transactions = {} + prisma_client.team_list_transactions = {} + prisma_client.team_member_list_transactions = {} + prisma_client.org_list_transactions = {} @staticmethod def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]: @@ -110,77 +139,110 @@ class RedisUpdateBuffer: """ return {key.replace(prefix, "", 1): value for key, value in data.items()} - async def get_all_update_transactions_from_redis( + async def get_all_update_transactions_from_redis_buffer( self, ) -> Optional[DBSpendUpdateTransactions]: """ Gets all the update transactions from Redis + + On Redis we store a list of transactions as a JSON string + + eg. + [ + DBSpendUpdateTransactions( + user_list_transactions={ + "user_id_1": 1.2, + "user_id_2": 0.01, + }, + end_user_list_transactions={}, + key_list_transactions={ + "0929880201": 1.2, + "0929880202": 0.01, + }, + team_list_transactions={}, + team_member_list_transactions={}, + org_list_transactions={}, + ), + DBSpendUpdateTransactions( + user_list_transactions={ + "user_id_3": 1.2, + "user_id_4": 0.01, + }, + end_user_list_transactions={}, + key_list_transactions={ + "key_id_1": 1.2, + "key_id_2": 0.01, + }, + team_list_transactions={}, + team_member_list_transactions={}, + org_list_transactions={}, + ] """ if self.redis_cache is None: return None - user_transaction_keys = await self.redis_cache.async_scan_iter( - "user_list_transactions:*" + list_of_transactions = await self.redis_cache.async_lpop( + key=REDIS_UPDATE_BUFFER_KEY, + count=MAX_REDIS_BUFFER_DEQUEUE_COUNT, ) - end_user_transaction_keys = await self.redis_cache.async_scan_iter( - "end_user_list_transactions:*" - ) - key_transaction_keys = await self.redis_cache.async_scan_iter( - "key_list_transactions:*" - ) - team_transaction_keys = await self.redis_cache.async_scan_iter( - "team_list_transactions:*" - ) - team_member_transaction_keys = await self.redis_cache.async_scan_iter( - "team_member_list_transactions:*" - ) - org_transaction_keys = await self.redis_cache.async_scan_iter( - "org_list_transactions:*" + if list_of_transactions is None: + return None + + # Parse the list of transactions from JSON strings + parsed_transactions = self._parse_list_of_transactions(list_of_transactions) + + # If there are no transactions, return None + if len(parsed_transactions) == 0: + return None + + # Combine all transactions into a single transaction + combined_transaction = self._combine_list_of_transactions(parsed_transactions) + + return combined_transaction + + @staticmethod + def _parse_list_of_transactions( + list_of_transactions: List[str], + ) -> List[DBSpendUpdateTransactions]: + """ + Parses the list of transactions from Redis + """ + return [json.loads(transaction) for transaction in list_of_transactions] + + @staticmethod + def _combine_list_of_transactions( + list_of_transactions: List[DBSpendUpdateTransactions], + ) -> DBSpendUpdateTransactions: + """ + Combines the list of transactions into a single DBSpendUpdateTransactions object + """ + # Initialize a new combined transaction object with empty dictionaries + combined_transaction = DBSpendUpdateTransactions( + user_list_transactions={}, + end_user_list_transactions={}, + key_list_transactions={}, + team_list_transactions={}, + team_member_list_transactions={}, + org_list_transactions={}, ) - user_list_transactions = await self.redis_cache.async_batch_get_cache( - user_transaction_keys - ) - end_user_list_transactions = await self.redis_cache.async_batch_get_cache( - end_user_transaction_keys - ) - key_list_transactions = await self.redis_cache.async_batch_get_cache( - key_transaction_keys - ) - team_list_transactions = await self.redis_cache.async_batch_get_cache( - team_transaction_keys - ) - team_member_list_transactions = await self.redis_cache.async_batch_get_cache( - team_member_transaction_keys - ) - org_list_transactions = await self.redis_cache.async_batch_get_cache( - org_transaction_keys - ) + # Define the transaction fields to process + transaction_fields = [ + "user_list_transactions", + "end_user_list_transactions", + "key_list_transactions", + "team_list_transactions", + "team_member_list_transactions", + "org_list_transactions", + ] - # filter out the "prefix" from the keys using the helper method - user_list_transactions = self._remove_prefix_from_keys( - user_list_transactions, "user_list_transactions:" - ) - end_user_list_transactions = self._remove_prefix_from_keys( - end_user_list_transactions, "end_user_list_transactions:" - ) - key_list_transactions = self._remove_prefix_from_keys( - key_list_transactions, "key_list_transactions:" - ) - team_list_transactions = self._remove_prefix_from_keys( - team_list_transactions, "team_list_transactions:" - ) - team_member_list_transactions = self._remove_prefix_from_keys( - team_member_list_transactions, "team_member_list_transactions:" - ) - org_list_transactions = self._remove_prefix_from_keys( - org_list_transactions, "org_list_transactions:" - ) + # Loop through each transaction and combine the values + for transaction in list_of_transactions: + # Process each field type + for field in transaction_fields: + if transaction.get(field): + for entity_id, amount in transaction[field].items(): + combined_transaction[field][entity_id] = ( + combined_transaction[field].get(entity_id, 0) + amount + ) - return DBSpendUpdateTransactions( - user_list_transactions=user_list_transactions, - end_user_list_transactions=end_user_list_transactions, - key_list_transactions=key_list_transactions, - team_list_transactions=team_list_transactions, - team_member_list_transactions=team_member_list_transactions, - org_list_transactions=org_list_transactions, - ) + return combined_transaction