litellm-mirror/litellm/proxy/db/spend_update_queue.py
2025-03-31 19:28:17 -07:00

94 lines
3.6 KiB
Python

import asyncio
from typing import TYPE_CHECKING, Any, Dict, List
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
DBSpendUpdateTransactions,
Litellm_EntityType,
SpendUpdateQueueItem,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
else:
PrismaClient = Any
class SpendUpdateQueue:
    """
    In memory buffer for spend updates that should be committed to the database
    """

    def __init__(
        self,
    ):
        # Pending spend updates, held until a flush drains the queue.
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()

    async def add_update(self, update: SpendUpdateQueueItem) -> None:
        """
        Enqueue a single spend update.

        The aggregation step reads the keys ``entity_type``, ``entity_id`` and
        ``response_cost`` from each update, e.g.
        ``{'entity_type': <Litellm_EntityType>, 'entity_id': '123', 'response_cost': 1.2}``.

        Args:
            update: The spend update to buffer.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

    async def flush_all_updates_from_in_memory_queue(
        self,
    ) -> List[SpendUpdateQueueItem]:
        """
        Remove every update currently in the queue and return them.

        Returns:
            The drained updates, in the order they were taken off the queue.
            An empty list when the queue holds nothing.
        """
        updates: List[SpendUpdateQueueItem] = []
        # empty() is re-checked before every get(), so the loop ends as soon
        # as the queue has been drained.
        while not self.update_queue.empty():
            updates.append(await self.update_queue.get())
        return updates

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """
        Flush all updates from the queue and return them aggregated by entity type.

        Returns:
            The result of ``get_aggregated_db_spend_update_transactions`` applied
            to everything that was drained from the queue.

        Raises:
            ValueError: If a drained update is missing ``entity_type``,
                ``entity_id`` or ``response_cost``.
        """
        updates = await self.flush_all_updates_from_in_memory_queue()
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """
        Sum the ``response_cost`` of the given updates per entity id, grouped by entity type.

        Args:
            updates: Spend updates to aggregate. Each one must provide
                ``entity_type``, ``entity_id`` and ``response_cost``.

        Returns:
            A ``DBSpendUpdateTransactions`` whose per-entity-type dictionaries map
            each entity id to its summed response cost. Updates whose entity type
            has no matching transaction dictionary are skipped.

        Raises:
            ValueError: If an update has ``None`` (or no value) for
                ``entity_type``, ``entity_id`` or ``response_cost``.
        """
        # Start every per-entity-type bucket as an empty dict
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
        )

        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
        }

        for update in updates:
            entity_type = update.get("entity_type")
            entity_id = update.get("entity_id")
            response_cost = update.get("response_cost")

            if entity_type is None or entity_id is None or response_cost is None:
                # Exceptions do not interpolate logging-style "%s" arguments,
                # so the message is formatted explicitly here.
                raise ValueError(f"Invalid update: {update}")

            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                continue  # Skip unknown entity types

            transactions_dict = db_spend_update_transactions[dict_key]
            if transactions_dict is None:
                # Defensive: recreate the bucket if it is unexpectedly None.
                transactions_dict = {}
                db_spend_update_transactions[dict_key] = transactions_dict

            # Accumulate spend for this entity id within its bucket.
            transactions_dict[entity_id] = (
                transactions_dict.get(entity_id, 0) + response_cost
            )

        return db_spend_update_transactions