# Patch summary (commentary only; everything from the first "diff --git" header onward is the unmodified patch).
#
# Purpose: start buffering in-memory spend `UPDATE` transactions in Redis before they are committed to the database.
#
# litellm/proxy/db/db_spend_update_writer.py
#   * PrismaClient / ProxyLogging are now imported only under TYPE_CHECKING and aliased to Any at runtime;
#     ProxyUpdateSpend, hash_token, _raise_failed_update_spend_exception and get_logging_payload become
#     function-local imports (presumably to break an import cycle with litellm.proxy.utils -- TODO confirm).
#   * DBSpendUpdateWriter gains __init__(redis_cache=None), which stores the cache and builds a RedisUpdateBuffer.
#   * The static db_spend_transaction_handler becomes the instance method db_update_spend_transaction_handler;
#     its Redis branch, previously `pass`, now awaits
#     self.redis_update_buffer.store_in_memory_spend_updates_in_redis(prisma_client=...).
#   * _should_commit_spend_updates_to_redis is removed here and re-added, body unchanged, on RedisUpdateBuffer.
#
# litellm/proxy/db/redis_update_buffer.py (new file)
#   * DBSpendUpdateTransactions: TypedDict with six Dict[str, float] fields named after the
#     prisma_client.*_list_transactons attributes.
#   * RedisUpdateBuffer._should_commit_spend_updates_to_redis: reads general_settings["use_redis_transaction_buffer"]
#     (default False), converts a str value with str_to_bool, and returns False when the result is None.
#   * RedisUpdateBuffer.store_in_memory_spend_updates_in_redis: collects the six in-memory dicts from prisma_client
#     and passes each one to increment_all_transaction_objects_in_redis.
#   * RedisUpdateBuffer.increment_all_transaction_objects_in_redis: returns immediately when redis_cache is None;
#     otherwise awaits redis_cache.async_increment(key=f"{key}:{transaction_id}", value=amount) once per entry.
#   * RedisUpdateBuffer.get_all_update_transactions_from_redis: stub (`pass`).
#
# litellm/proxy/utils.py
#   * ProxyLogging.__init__ creates self.db_spend_update_writer from internal_usage_cache.dual_cache.redis_cache.
#   * update_spend now awaits proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(...)
#     instead of calling DBSpendUpdateWriter._commit_spend_updates_to_db(...) directly.
#
# NOTE(review): no hunk shown here clears the prisma_client.*_list_transactons dicts after they are incremented
#   into Redis, and the handler still falls through to the _should_commit_spend_updates_to_db branch -- confirm the
#   same amounts cannot be counted twice or re-sent to Redis on the next flush.
# NOTE(review): db_spend_transaction_handler is renamed and is no longer static -- verify no caller outside this
#   patch still uses the old name.
# NOTE(review): the "transactons" spelling matches existing prisma_client attribute names and becomes the Redis key
#   prefix via f"{key}:{transaction_id}" -- do not rename it in isolation.
# NOTE: in this copy the patch's line breaks and indentation appear collapsed to single spaces; the @@ hunk counts
#   (e.g. "+1,110" for the new file) describe the original line layout, not this one.
diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index d85e69c472..00f60a76f2 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -10,26 +10,24 @@ import os import time import traceback from datetime import datetime, timedelta -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching import DualCache, RedisCache, RedisClusterCache from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, Litellm_EntityType, LiteLLM_UserTable, SpendLogsPayload, ) -from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload -from litellm.proxy.utils import ( - PrismaClient, - ProxyLogging, - ProxyUpdateSpend, - _raise_failed_update_spend_exception, - hash_token, -) -from litellm.secret_managers.main import str_to_bool +from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer + +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient, ProxyLogging +else: + PrismaClient = Any + ProxyLogging = Any class DBSpendUpdateWriter: @@ -40,6 +38,12 @@ class DBSpendUpdateWriter: 2. 
Reading increments from redis or in memory list of transactions and committing them to db """ + def __init__( + self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None + ): + self.redis_cache = redis_cache + self.redis_update_buffer = RedisUpdateBuffer(redis_cache=redis_cache) + @staticmethod async def update_database( # LiteLLM management object fields @@ -61,6 +65,7 @@ class DBSpendUpdateWriter: prisma_client, user_api_key_cache, ) + from litellm.proxy.utils import ProxyUpdateSpend, hash_token try: verbose_proxy_logger.debug( @@ -315,6 +320,10 @@ class DBSpendUpdateWriter: response_cost: Optional[float], prisma_client: Optional[PrismaClient], ): + from litellm.proxy.spend_tracking.spend_tracking_utils import ( + get_logging_payload, + ) + try: if prisma_client: payload = get_logging_payload( @@ -360,8 +369,8 @@ class DBSpendUpdateWriter: ) return prisma_client - @staticmethod - async def db_spend_transaction_handler( + async def db_update_spend_transaction_handler( + self, prisma_client: PrismaClient, n_retry_times: int, proxy_logging_obj: ProxyLogging, @@ -383,8 +392,10 @@ class DBSpendUpdateWriter: else: - Regular flow of this method """ - if DBSpendUpdateWriter._should_commit_spend_updates_to_redis(): - pass + if RedisUpdateBuffer._should_commit_spend_updates_to_redis(): + await self.redis_update_buffer.store_in_memory_spend_updates_in_redis( + prisma_client=prisma_client, + ) if DBSpendUpdateWriter._should_commit_spend_updates_to_db(): await DBSpendUpdateWriter._commit_spend_updates_to_db( @@ -395,25 +406,6 @@ class DBSpendUpdateWriter: pass - @staticmethod - def _should_commit_spend_updates_to_redis() -> bool: - """ - Checks if the Pod should commit spend updates to Redis - - This setting enables buffering database transactions in Redis - to improve reliability and reduce database contention - """ - from litellm.proxy.proxy_server import general_settings - - _use_redis_transaction_buffer: Optional[Union[bool, str]] = ( - 
general_settings.get("use_redis_transaction_buffer", False) - ) - if isinstance(_use_redis_transaction_buffer, str): - _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) - if _use_redis_transaction_buffer is None: - return False - return _use_redis_transaction_buffer - @staticmethod async def _commit_spend_updates_to_redis( prisma_client: PrismaClient, @@ -439,8 +431,14 @@ class DBSpendUpdateWriter: proxy_logging_obj: ProxyLogging, ): """ - Commits all the spend updates to the Database + Commits all the spend `UPDATE` transactions to the Database + """ + from litellm.proxy.utils import ( + ProxyUpdateSpend, + _raise_failed_update_spend_exception, + ) + ### UPDATE USER TABLE ### if len(prisma_client.user_list_transactons.keys()) > 0: for i in range(n_retry_times + 1): diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py new file mode 100644 index 0000000000..843ec445a3 --- /dev/null +++ b/litellm/proxy/db/redis_update_buffer.py @@ -0,0 +1,110 @@ +""" +Handles buffering database `UPDATE` transactions in Redis before committing them to the database + +This is to prevent deadlocks and improve reliability +""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast + +from litellm.caching import RedisCache, RedisClusterCache +from litellm.secret_managers.main import str_to_bool + +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient +else: + PrismaClient = Any + + +class DBSpendUpdateTransactions(TypedDict): + user_list_transactons: Dict[str, float] + end_user_list_transactons: Dict[str, float] + key_list_transactons: Dict[str, float] + team_list_transactons: Dict[str, float] + team_member_list_transactons: Dict[str, float] + org_list_transactons: Dict[str, float] + + +class RedisUpdateBuffer: + """ + Handles buffering database `UPDATE` transactions in Redis before committing them to the database + + This is to prevent deadlocks and improve reliability + """ + + def 
__init__( + self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None + ): + self.redis_cache = redis_cache + + @staticmethod + def _should_commit_spend_updates_to_redis() -> bool: + """ + Checks if the Pod should commit spend updates to Redis + + This setting enables buffering database transactions in Redis + to improve reliability and reduce database contention + """ + from litellm.proxy.proxy_server import general_settings + + _use_redis_transaction_buffer: Optional[Union[bool, str]] = ( + general_settings.get("use_redis_transaction_buffer", False) + ) + if isinstance(_use_redis_transaction_buffer, str): + _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) + if _use_redis_transaction_buffer is None: + return False + return _use_redis_transaction_buffer + + async def store_in_memory_spend_updates_in_redis( + self, + prisma_client: PrismaClient, + ): + """ + Stores the in-memory spend updates to Redis + + Each transaction is a dict stored as following: + - key is the entity id + - value is the spend amount + + ``` + { + "0929880201": 10, + "0929880202": 20, + "0929880203": 30, + } + ``` + """ + IN_MEMORY_UPDATE_TRANSACTIONS: DBSpendUpdateTransactions = ( + DBSpendUpdateTransactions( + user_list_transactons=prisma_client.user_list_transactons, + end_user_list_transactons=prisma_client.end_user_list_transactons, + key_list_transactons=prisma_client.key_list_transactons, + team_list_transactons=prisma_client.team_list_transactons, + team_member_list_transactons=prisma_client.team_member_list_transactons, + org_list_transactons=prisma_client.org_list_transactons, + ) + ) + for key, _transactions in IN_MEMORY_UPDATE_TRANSACTIONS.items(): + await self.increment_all_transaction_objects_in_redis( + key=key, + transactions=cast(Dict, _transactions), + ) + + async def increment_all_transaction_objects_in_redis( + self, + key: str, + transactions: Dict, + ): + """ + Increments all transaction objects in Redis + """ + if self.redis_cache 
is None: + return + for transaction_id, transaction_amount in transactions.items(): + await self.redis_cache.async_increment( + key=f"{key}:{transaction_id}", + value=transaction_amount, + ) + + async def get_all_update_transactions_from_redis(self): + pass diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 435bf38c0e..c208f60529 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -265,6 +265,9 @@ class ProxyLogging: ) self.premium_user = premium_user self.service_logging_obj = ServiceLogging() + self.db_spend_update_writer = DBSpendUpdateWriter( + redis_cache=self.internal_usage_cache.dual_cache.redis_cache + ) def startup_event( self, @@ -2675,7 +2678,7 @@ async def update_spend( # noqa: PLR0915 spend_logs: list, """ n_retry_times = 3 - await DBSpendUpdateWriter._commit_spend_updates_to_db( + await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler( prisma_client=prisma_client, n_retry_times=n_retry_times, proxy_logging_obj=proxy_logging_obj,