Fix DB spend update buffer

This commit is contained in:
Ishaan Jaff 2025-03-27 22:34:15 -07:00
parent 1bfffadd05
commit fc46f6b861
4 changed files with 20 additions and 9 deletions

View file

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache, RedisCache, RedisClusterCache
from litellm.caching import DualCache, RedisCache
from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
@ -44,12 +44,13 @@ class DBSpendUpdateWriter:
"""
def __init__(
self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
self,
redis_cache: Optional[RedisCache] = None,
):
from litellm.proxy.proxy_server import prisma_client
self.redis_cache = redis_cache
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=redis_cache)
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
self.pod_leader_manager = PodLockManager(
cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
prisma_client=prisma_client,
@ -408,6 +409,7 @@ class DBSpendUpdateWriter:
prisma_client=prisma_client,
)
# Only commit from redis to db if this pod is the leader
if await self.pod_leader_manager.acquire_lock():
db_spend_update_transactions = (
await self.redis_update_buffer.get_all_update_transactions_from_redis()

View file

@ -6,7 +6,8 @@ This is to prevent deadlocks and improve reliability
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast
from litellm.caching import RedisCache, RedisClusterCache
from litellm._logging import verbose_proxy_logger
from litellm.caching import RedisCache
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:
@ -32,7 +33,8 @@ class RedisUpdateBuffer:
"""
def __init__(
self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
self,
redis_cache: Optional[RedisCache] = None,
):
self.redis_cache = redis_cache
@ -99,6 +101,9 @@ class RedisUpdateBuffer:
Increments all transaction objects in Redis
"""
if self.redis_cache is None:
verbose_proxy_logger.debug(
"redis_cache is None, skipping increment_all_transaction_objects_in_redis"
)
return
for transaction_id, transaction_amount in transactions.items():
await self.redis_cache.async_increment(

View file

@ -6,4 +6,9 @@ model_list:
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
allow_requests_on_db_unavailable: True
use_redis_transaction_buffer: True
litellm_settings:
cache: true
cache_params:
type: redis

View file

@ -265,9 +265,7 @@ class ProxyLogging:
)
self.premium_user = premium_user
self.service_logging_obj = ServiceLogging()
self.db_spend_update_writer = DBSpendUpdateWriter(
redis_cache=self.internal_usage_cache.dual_cache.redis_cache
)
self.db_spend_update_writer = DBSpendUpdateWriter()
def startup_event(
self,
@ -340,6 +338,7 @@ class ProxyLogging:
if redis_cache is not None:
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore