""" Handles buffering database `UPDATE` transactions in Redis before committing them to the database This is to prevent deadlocks and improve reliability """ import json from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from litellm._logging import verbose_proxy_logger from litellm.caching import RedisCache from litellm.constants import MAX_REDIS_BUFFER_DEQUEUE_COUNT, REDIS_UPDATE_BUFFER_KEY from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import DBSpendUpdateTransactions from litellm.proxy.db.spend_update_queue import SpendUpdateQueue from litellm.secret_managers.main import str_to_bool if TYPE_CHECKING: from litellm.proxy.utils import PrismaClient else: PrismaClient = Any class RedisUpdateBuffer: """ Handles buffering database `UPDATE` transactions in Redis before committing them to the database This is to prevent deadlocks and improve reliability """ def __init__( self, redis_cache: Optional[RedisCache] = None, ): self.redis_cache = redis_cache @staticmethod def _should_commit_spend_updates_to_redis() -> bool: """ Checks if the Pod should commit spend updates to Redis This setting enables buffering database transactions in Redis to improve reliability and reduce database contention """ from litellm.proxy.proxy_server import general_settings _use_redis_transaction_buffer: Optional[ Union[bool, str] ] = general_settings.get("use_redis_transaction_buffer", False) if isinstance(_use_redis_transaction_buffer, str): _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) if _use_redis_transaction_buffer is None: return False return _use_redis_transaction_buffer async def store_in_memory_spend_updates_in_redis( self, spend_update_queue: SpendUpdateQueue, ): """ Stores the in-memory spend updates to Redis Each transaction is a dict stored as following: - key is the entity id - value is the spend amount ``` Redis List: key_list_transactions: [ "0929880201": 1.2, "0929880202": 0.01, "0929880203": 0.001, ] ``` """ if self.redis_cache is None: verbose_proxy_logger.debug( "redis_cache is None, skipping store_in_memory_spend_updates_in_redis" ) return db_spend_update_transactions = ( await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions() ) verbose_proxy_logger.debug( "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions ) # only store in redis if there are any updates to commit if ( self._number_of_transactions_to_store_in_redis(db_spend_update_transactions) == 0 ): return list_of_transactions = [safe_dumps(db_spend_update_transactions)] await self.redis_cache.async_rpush( key=REDIS_UPDATE_BUFFER_KEY, values=list_of_transactions, ) @staticmethod def _number_of_transactions_to_store_in_redis( db_spend_update_transactions: DBSpendUpdateTransactions, ) -> int: """ Gets the number of transactions to store in Redis """ num_transactions = 0 for v in db_spend_update_transactions.values(): if isinstance(v, dict): num_transactions += len(v) return num_transactions @staticmethod def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]: """ Removes the specified prefix from the keys of a dictionary. """ return {key.replace(prefix, "", 1): value for key, value in data.items()} async def get_all_update_transactions_from_redis_buffer( self, ) -> Optional[DBSpendUpdateTransactions]: """ Gets all the update transactions from Redis On Redis we store a list of transactions as a JSON string eg. 
[ DBSpendUpdateTransactions( user_list_transactions={ "user_id_1": 1.2, "user_id_2": 0.01, }, end_user_list_transactions={}, key_list_transactions={ "0929880201": 1.2, "0929880202": 0.01, }, team_list_transactions={}, team_member_list_transactions={}, org_list_transactions={}, ), DBSpendUpdateTransactions( user_list_transactions={ "user_id_3": 1.2, "user_id_4": 0.01, }, end_user_list_transactions={}, key_list_transactions={ "key_id_1": 1.2, "key_id_2": 0.01, }, team_list_transactions={}, team_member_list_transactions={}, org_list_transactions={}, ] """ if self.redis_cache is None: return None list_of_transactions = await self.redis_cache.async_lpop( key=REDIS_UPDATE_BUFFER_KEY, count=MAX_REDIS_BUFFER_DEQUEUE_COUNT, ) if list_of_transactions is None: return None # Parse the list of transactions from JSON strings parsed_transactions = self._parse_list_of_transactions(list_of_transactions) # If there are no transactions, return None if len(parsed_transactions) == 0: return None # Combine all transactions into a single transaction combined_transaction = self._combine_list_of_transactions(parsed_transactions) return combined_transaction @staticmethod def _parse_list_of_transactions( list_of_transactions: Union[Any, List[Any]], ) -> List[DBSpendUpdateTransactions]: """ Parses the list of transactions from Redis """ if isinstance(list_of_transactions, list): return [json.loads(transaction) for transaction in list_of_transactions] else: return [json.loads(list_of_transactions)] @staticmethod def _combine_list_of_transactions( list_of_transactions: List[DBSpendUpdateTransactions], ) -> DBSpendUpdateTransactions: """ Combines the list of transactions into a single DBSpendUpdateTransactions object """ # Initialize a new combined transaction object with empty dictionaries combined_transaction = DBSpendUpdateTransactions( user_list_transactions={}, end_user_list_transactions={}, key_list_transactions={}, team_list_transactions={}, team_member_list_transactions={}, org_list_transactions={}, ) # Define the transaction fields to process transaction_fields = [ "user_list_transactions", "end_user_list_transactions", "key_list_transactions", "team_list_transactions", "team_member_list_transactions", "org_list_transactions", ] # Loop through each transaction and combine the values for transaction in list_of_transactions: # Process each field type for field in transaction_fields: if transaction.get(field): for entity_id, amount in transaction[field].items(): # type: ignore combined_transaction[field][entity_id] = ( # type: ignore combined_transaction[field].get(entity_id, 0) + amount # type: ignore ) return combined_transaction
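# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): shows one way
# the write path (flushing in-memory spend updates into the Redis list
# buffer) and the read path (popping buffered transactions and combining
# them for a single DB commit) could be exercised together. The Redis
# connection parameters and the assumption that `SpendUpdateQueue()` takes
# no arguments and is populated elsewhere by the proxy's spend-tracking
# hooks are placeholders, not guarantees about the library's API.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _example_flush_and_commit() -> None:
        # hypothetical connection parameters for a local Redis instance
        redis_cache = RedisCache(host="localhost", port=6379)
        redis_buffer = RedisUpdateBuffer(redis_cache=redis_cache)

        # assumed to be filled by the proxy's spend-tracking hooks elsewhere
        spend_update_queue = SpendUpdateQueue()

        # write path: drain the in-memory queue into the Redis buffer
        await redis_buffer.store_in_memory_spend_updates_in_redis(
            spend_update_queue=spend_update_queue,
        )

        # read path: pop buffered transactions and combine them into a single
        # DBSpendUpdateTransactions object, ready to be written to the DB
        combined = await redis_buffer.get_all_update_transactions_from_redis_buffer()
        if combined is not None:
            verbose_proxy_logger.debug("combined transactions: %s", combined)

    asyncio.run(_example_flush_and_commit())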