class BaseUpdateQueue:
    """Base class for in-memory update queues.

    Wraps an ``asyncio.Queue`` and provides the two operations shared by the
    update queues built on top of it: enqueueing one update and draining
    everything that is currently buffered.
    """

    def __init__(self):
        # Created without a maxsize, i.e. unbounded.
        self.update_queue = asyncio.Queue()

    async def add_update(self, update):
        """Enqueue a single update.

        Args:
            update: The item to buffer; it is logged at debug level and put
                on ``self.update_queue`` unchanged.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

    async def flush_all_updates_from_in_memory_queue(self):
        """Remove and return every update currently buffered, in FIFO order.

        Returns:
            A list of the drained updates; empty when nothing was queued.
        """
        updates = []
        while True:
            try:
                # get_nowait() never suspends, so the drain cannot end up
                # waiting on an empty queue the way an `empty()` check
                # followed by `await get()` potentially could.
                updates.append(self.update_queue.get_nowait())
            except asyncio.QueueEmpty:
                return updates
class DailySpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for daily spend updates that should be committed to the database

    To add a new daily spend update transaction, use the following format:
        daily_spend_update_queue.add_update({
            "user1_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
                "api_requests": 100,
                "successful_requests": 100,
                "failed_requests": 100,
            }
        })

    Queue contains a list of daily spend update transactions

    eg
        queue = [
            {
                "user1_date_api_key_model_custom_llm_provider": {
                    "spend": 10,
                    "prompt_tokens": 100,
                    "completion_tokens": 100,
                    "api_requests": 100,
                    "successful_requests": 100,
                    "failed_requests": 100,
                }
            },
            {
                "user2_date_api_key_model_custom_llm_provider": {
                    "spend": 10,
                    "prompt_tokens": 100,
                    "completion_tokens": 100,
                    "api_requests": 100,
                    "successful_requests": 100,
                    "failed_requests": 100,
                }
            }
        ]
    """

    def __init__(self):
        super().__init__()
        # Re-declared here so the queue carries the item type of this subclass.
        self.update_queue: asyncio.Queue[
            Dict[str, DailyUserSpendTransaction]
        ] = asyncio.Queue()

    async def flush_and_get_aggregated_daily_spend_update_transactions(
        self,
    ) -> Dict[str, DailyUserSpendTransaction]:
        """Drain the queue and return the updates aggregated by daily_transaction_key.

        Returns:
            One transaction per key, with the numeric fields summed across
            every queued update that shared that key.
        """
        updates = await self.flush_all_updates_from_in_memory_queue()
        aggregated_daily_spend_update_transactions = (
            DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                updates
            )
        )
        verbose_proxy_logger.debug(
            "Aggregated daily spend update transactions: %s",
            aggregated_daily_spend_update_transactions,
        )
        return aggregated_daily_spend_update_transactions

    @staticmethod
    def get_aggregated_daily_spend_update_transactions(
        updates: List[Dict[str, DailyUserSpendTransaction]]
    ) -> Dict[str, DailyUserSpendTransaction]:
        """Aggregate updates by daily_transaction_key.

        The first payload seen for a key seeds the aggregate; later payloads
        for the same key add their ``spend``, ``prompt_tokens``,
        ``completion_tokens``, ``api_requests``, ``successful_requests`` and
        ``failed_requests`` onto it. Any other fields keep the values from
        the first payload.

        Args:
            updates: Queued updates, each mapping a daily_transaction_key to
                its transaction payload. The payloads are not modified.

        Returns:
            A new dict mapping each key to its aggregated transaction.
        """
        aggregated_daily_spend_update_transactions: Dict[
            str, DailyUserSpendTransaction
        ] = {}
        for _update in updates:
            for _key, payload in _update.items():
                if _key not in aggregated_daily_spend_update_transactions:
                    # Store a copy: the aggregate is mutated by later `+=`
                    # operations, and writing through to the caller's payload
                    # would corrupt the input (re-aggregating the same list
                    # would then inflate the totals).
                    # NOTE(review): shallow copy — assumes the summed fields
                    # are plain numbers; confirm against DailyUserSpendTransaction.
                    aggregated_daily_spend_update_transactions[_key] = payload.copy()
                    continue

                daily_transaction = aggregated_daily_spend_update_transactions[_key]
                daily_transaction["spend"] += payload["spend"]
                daily_transaction["prompt_tokens"] += payload["prompt_tokens"]
                daily_transaction["completion_tokens"] += payload["completion_tokens"]
                daily_transaction["api_requests"] += payload["api_requests"]
                daily_transaction["successful_requests"] += payload[
                    "successful_requests"
                ]
                daily_transaction["failed_requests"] += payload["failed_requests"]
        return aggregated_daily_spend_update_transactions
DBSpendUpdateTransactions -from litellm.proxy.db.spend_update_queue import DailySpendUpdateQueue, SpendUpdateQueue +from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import ( + DailySpendUpdateQueue, +) +from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue from litellm.secret_managers.main import str_to_bool if TYPE_CHECKING: diff --git a/litellm/proxy/db/spend_update_queue.py b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py similarity index 52% rename from litellm/proxy/db/spend_update_queue.py rename to litellm/proxy/db/db_transaction_queue/spend_update_queue.py index e77f945094..ce181d1478 100644 --- a/litellm/proxy/db/spend_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py @@ -1,44 +1,24 @@ import asyncio -from typing import TYPE_CHECKING, Any, Dict, List +from typing import List from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ( - DailyUserSpendTransaction, DBSpendUpdateTransactions, Litellm_EntityType, SpendUpdateQueueItem, ) - -if TYPE_CHECKING: - from litellm.proxy.utils import PrismaClient -else: - PrismaClient = Any +from litellm.proxy.db.db_transaction_queue.base_update_queue import BaseUpdateQueue -class SpendUpdateQueue: +class SpendUpdateQueue(BaseUpdateQueue): """ In memory buffer for spend updates that should be committed to the database """ - def __init__( - self, - ): + def __init__(self): + super().__init__() self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue() - async def add_update(self, update: SpendUpdateQueueItem) -> None: - """Enqueue an update. 
Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}.""" - verbose_proxy_logger.debug("Adding update to queue: %s", update) - await self.update_queue.put(update) - - async def flush_all_updates_from_in_memory_queue( - self, - ) -> List[SpendUpdateQueueItem]: - """Get all updates from the queue.""" - updates: List[SpendUpdateQueueItem] = [] - while not self.update_queue.empty(): - updates.append(await self.update_queue.get()) - return updates - async def flush_and_get_aggregated_db_spend_update_transactions( self, ) -> DBSpendUpdateTransactions: @@ -131,79 +111,3 @@ class SpendUpdateQueue: transactions_dict[entity_id] += response_cost or 0 return db_spend_update_transactions - - -class DailySpendUpdateQueue: - def __init__( - self, - ): - self.update_queue: asyncio.Queue[ - Dict[str, DailyUserSpendTransaction] - ] = asyncio.Queue() - - async def add_update(self, update: Dict[str, DailyUserSpendTransaction]) -> None: - """Enqueue an update. Each update might be a dict like - { - "user_date_api_key_model_custom_llm_provider": { - "spend": 1.2, - "prompt_tokens": 1000, - "completion_tokens": 1000, - "api_requests": 1000, - "successful_requests": 1000, - "failed_requests": 1000, - } - } - .""" - verbose_proxy_logger.debug("Adding update to queue: %s", update) - await self.update_queue.put(update) - - async def flush_all_updates_from_in_memory_queue( - self, - ) -> List[Dict[str, DailyUserSpendTransaction]]: - """Get all updates from the queue.""" - updates: List[Dict[str, DailyUserSpendTransaction]] = [] - while not self.update_queue.empty(): - updates.append(await self.update_queue.get()) - return updates - - async def flush_and_get_aggregated_daily_spend_update_transactions( - self, - ) -> Dict[str, DailyUserSpendTransaction]: - """Get all updates from the queue and return all updates aggregated by daily_transaction_key.""" - updates = await self.flush_all_updates_from_in_memory_queue() - 
aggregated_daily_spend_update_transactions = ( - DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions( - updates - ) - ) - verbose_proxy_logger.debug( - "Aggregated daily spend update transactions: %s", - aggregated_daily_spend_update_transactions, - ) - return aggregated_daily_spend_update_transactions - - @staticmethod - def get_aggregated_daily_spend_update_transactions( - updates: List[Dict[str, DailyUserSpendTransaction]] - ) -> Dict[str, DailyUserSpendTransaction]: - """Aggregate updates by daily_transaction_key.""" - aggregated_daily_spend_update_transactions: Dict[ - str, DailyUserSpendTransaction - ] = {} - for _update in updates: - for _key, payload in _update.items(): - if _key in aggregated_daily_spend_update_transactions: - daily_transaction = aggregated_daily_spend_update_transactions[_key] - daily_transaction["spend"] += payload["spend"] - daily_transaction["prompt_tokens"] += payload["prompt_tokens"] - daily_transaction["completion_tokens"] += payload[ - "completion_tokens" - ] - daily_transaction["api_requests"] += payload["api_requests"] - daily_transaction["successful_requests"] += payload[ - "successful_requests" - ] - daily_transaction["failed_requests"] += payload["failed_requests"] - else: - aggregated_daily_spend_update_transactions[_key] = payload - return aggregated_daily_spend_update_transactions diff --git a/tests/litellm/proxy/db/test_pod_lock_manager.py b/tests/litellm/proxy/db/test_pod_lock_manager.py index bce7b66409..cde4315837 100644 --- a/tests/litellm/proxy/db/test_pod_lock_manager.py +++ b/tests/litellm/proxy/db/test_pod_lock_manager.py @@ -12,7 +12,7 @@ sys.path.insert( ) # Adds the parent directory to the system path from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS -from litellm.proxy.db.pod_lock_manager import PodLockManager +from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager # Mock Prisma client class diff --git a/tests/litellm/proxy/db/test_spend_update_queue.py 
b/tests/litellm/proxy/db/test_spend_update_queue.py index 89d494a070..98d3b4e4c7 100644 --- a/tests/litellm/proxy/db/test_spend_update_queue.py +++ b/tests/litellm/proxy/db/test_spend_update_queue.py @@ -7,7 +7,7 @@ import pytest from fastapi.testclient import TestClient from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem -from litellm.proxy.db.spend_update_queue import SpendUpdateQueue +from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue sys.path.insert( 0, os.path.abspath("../../..") diff --git a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py index 7d36bb4791..3522c8e1e2 100644 --- a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py +++ b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py @@ -23,7 +23,7 @@ import asyncio import logging import pytest -from litellm.proxy.db.pod_lock_manager import PodLockManager +from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy.management_endpoints.internal_user_endpoints import (