From efe6d375e9adec23cc24f1934a41fc8e7e6fe298 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 31 Mar 2025 18:40:03 -0700 Subject: [PATCH] add new SpendUpdateQueue --- litellm/proxy/db/spend_update_queue.py | 54 ++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 litellm/proxy/db/spend_update_queue.py diff --git a/litellm/proxy/db/spend_update_queue.py b/litellm/proxy/db/spend_update_queue.py new file mode 100644 index 0000000000..778361be84 --- /dev/null +++ b/litellm/proxy/db/spend_update_queue.py @@ -0,0 +1,54 @@ +import asyncio +from typing import TYPE_CHECKING, Any, Dict, List + +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient +else: + PrismaClient = Any + + +class SpendUpdateQueue: + """ + Handles buffering database `UPDATE` transactions in Redis before committing them to the database + + This is to prevent deadlocks and improve reliability + """ + + def __init__( + self, + ): + self.update_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + + async def add_update(self, update: Dict[str, Any]) -> None: + """Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}.""" + await self.update_queue.put(update) + + async def flush_all_updates_from_in_memory_queue(self) -> List[Dict[str, Any]]: + """Get all updates from the queue.""" + updates: List[Dict[str, Any]] = [] + while not self.update_queue.empty(): + updates.append(await self.update_queue.get()) + return updates + + async def flush_and_get_all_aggregated_updates_by_entity_type( + self, + ) -> Dict[str, Any]: + """Flush all updates from the queue and return all updates aggregated by entity type.""" + updates = await self.flush_all_updates_from_in_memory_queue() + return self.aggregate_updates_by_entity_type(updates) + + def aggregate_updates_by_entity_type( + self, updates: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Aggregate updates by entity type.""" + aggregated_updates = {} + for update in updates: + entity_type = update["entity_type"] + entity_id = update["entity_id"] + amount = update["amount"] + if entity_type not in aggregated_updates: + aggregated_updates[entity_type] = {} + if entity_id not in aggregated_updates[entity_type]: + aggregated_updates[entity_type][entity_id] = 0 + aggregated_updates[entity_type][entity_id] += amount + return aggregated_updates