diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index c49d5aa989..0183fa433b 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache, RedisCache, RedisClusterCache +from litellm.caching import DualCache, RedisCache from litellm.constants import DB_SPEND_UPDATE_JOB_NAME from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, @@ -44,12 +44,13 @@ class DBSpendUpdateWriter: """ def __init__( - self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None + self, + redis_cache: Optional[RedisCache] = None, ): from litellm.proxy.proxy_server import prisma_client self.redis_cache = redis_cache - self.redis_update_buffer = RedisUpdateBuffer(redis_cache=redis_cache) + self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache) self.pod_leader_manager = PodLockManager( cronjob_id=DB_SPEND_UPDATE_JOB_NAME, prisma_client=prisma_client, @@ -408,6 +409,7 @@ class DBSpendUpdateWriter: prisma_client=prisma_client, ) + # Only commit from redis to db if this pod is the leader if await self.pod_leader_manager.acquire_lock(): db_spend_update_transactions = ( await self.redis_update_buffer.get_all_update_transactions_from_redis() diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py index d3591032ff..ea7f14d321 100644 --- a/litellm/proxy/db/redis_update_buffer.py +++ b/litellm/proxy/db/redis_update_buffer.py @@ -6,7 +6,8 @@ This is to prevent deadlocks and improve reliability from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast -from litellm.caching import RedisCache, RedisClusterCache +from litellm._logging import verbose_proxy_logger +from litellm.caching import RedisCache from litellm.secret_managers.main import str_to_bool if TYPE_CHECKING: @@ -32,7 +33,8 @@ class RedisUpdateBuffer: """ def __init__( - self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None + self, + redis_cache: Optional[RedisCache] = None, ): self.redis_cache = redis_cache @@ -99,6 +101,9 @@ class RedisUpdateBuffer: Increments all transaction objects in Redis """ if self.redis_cache is None: + verbose_proxy_logger.debug( + "redis_cache is None, skipping increment_all_transaction_objects_in_redis" + ) return for transaction_id, transaction_amount in transactions.items(): await self.redis_cache.async_increment( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 4912a35f89..a1be54421b 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -6,4 +6,9 @@ model_list: api_base: https://exampleopenaiendpoint-production.up.railway.app/ general_settings: - allow_requests_on_db_unavailable: True \ No newline at end of file + use_redis_transaction_buffer: True + +litellm_settings: + cache: true + cache_params: + type: redis diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index c208f60529..040c6c14ef 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -265,9 +265,7 @@ class ProxyLogging: ) self.premium_user = premium_user self.service_logging_obj = ServiceLogging() - self.db_spend_update_writer = DBSpendUpdateWriter( - redis_cache=self.internal_usage_cache.dual_cache.redis_cache - ) + self.db_spend_update_writer = DBSpendUpdateWriter() def startup_event( self, @@ -340,6 +338,7 @@ class ProxyLogging: if redis_cache is not None: self.internal_usage_cache.dual_cache.redis_cache = redis_cache + self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache def _init_litellm_callbacks(self, llm_router: Optional[Router] = None): litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore