use redis update buffer class

This commit is contained in:
Ishaan Jaff 2025-03-27 19:12:51 -07:00
parent ad72078167
commit 963791bbb5
3 changed files with 147 additions and 36 deletions

View file

@ -10,26 +10,24 @@ import os
import time import time
import traceback import traceback
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache from litellm.caching import DualCache, RedisCache, RedisClusterCache
from litellm.proxy._types import ( from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES, DB_CONNECTION_ERROR_TYPES,
Litellm_EntityType, Litellm_EntityType,
LiteLLM_UserTable, LiteLLM_UserTable,
SpendLogsPayload, SpendLogsPayload,
) )
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
from litellm.proxy.utils import (
PrismaClient, if TYPE_CHECKING:
ProxyLogging, from litellm.proxy.utils import PrismaClient, ProxyLogging
ProxyUpdateSpend, else:
_raise_failed_update_spend_exception, PrismaClient = Any
hash_token, ProxyLogging = Any
)
from litellm.secret_managers.main import str_to_bool
class DBSpendUpdateWriter: class DBSpendUpdateWriter:
@ -40,6 +38,12 @@ class DBSpendUpdateWriter:
2. Reading increments from redis or in memory list of transactions and committing them to db 2. Reading increments from redis or in memory list of transactions and committing them to db
""" """
def __init__(
    self,
    redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None,
):
    """
    Keep a handle on the Redis client and build the Redis update buffer on top of it.

    Args:
        redis_cache: Redis (or Redis cluster) cache client; may be None, in
            which case the buffer is constructed without a Redis client.
    """
    self.redis_cache = redis_cache
    # The buffer is handed the very same client object stored above.
    self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
@staticmethod @staticmethod
async def update_database( async def update_database(
# LiteLLM management object fields # LiteLLM management object fields
@ -61,6 +65,7 @@ class DBSpendUpdateWriter:
prisma_client, prisma_client,
user_api_key_cache, user_api_key_cache,
) )
from litellm.proxy.utils import ProxyUpdateSpend, hash_token
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -315,6 +320,10 @@ class DBSpendUpdateWriter:
response_cost: Optional[float], response_cost: Optional[float],
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
): ):
from litellm.proxy.spend_tracking.spend_tracking_utils import (
get_logging_payload,
)
try: try:
if prisma_client: if prisma_client:
payload = get_logging_payload( payload = get_logging_payload(
@ -360,8 +369,8 @@ class DBSpendUpdateWriter:
) )
return prisma_client return prisma_client
@staticmethod async def db_update_spend_transaction_handler(
async def db_spend_transaction_handler( self,
prisma_client: PrismaClient, prisma_client: PrismaClient,
n_retry_times: int, n_retry_times: int,
proxy_logging_obj: ProxyLogging, proxy_logging_obj: ProxyLogging,
@ -383,8 +392,10 @@ class DBSpendUpdateWriter:
else: else:
- Regular flow of this method - Regular flow of this method
""" """
if DBSpendUpdateWriter._should_commit_spend_updates_to_redis(): if RedisUpdateBuffer._should_commit_spend_updates_to_redis():
pass await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
prisma_client=prisma_client,
)
if DBSpendUpdateWriter._should_commit_spend_updates_to_db(): if DBSpendUpdateWriter._should_commit_spend_updates_to_db():
await DBSpendUpdateWriter._commit_spend_updates_to_db( await DBSpendUpdateWriter._commit_spend_updates_to_db(
@ -395,25 +406,6 @@ class DBSpendUpdateWriter:
pass pass
@staticmethod
def _should_commit_spend_updates_to_redis() -> bool:
    """
    Report whether this pod should buffer its spend updates in Redis.

    Reads `use_redis_transaction_buffer` from the proxy's `general_settings`
    (missing key -> False). A string value is converted with `str_to_bool`;
    a value of None after that conversion is reported as False.
    """
    from litellm.proxy.proxy_server import general_settings

    buffer_flag: Optional[Union[bool, str]] = general_settings.get(
        "use_redis_transaction_buffer", False
    )
    if isinstance(buffer_flag, str):
        buffer_flag = str_to_bool(buffer_flag)
    return False if buffer_flag is None else buffer_flag
@staticmethod @staticmethod
async def _commit_spend_updates_to_redis( async def _commit_spend_updates_to_redis(
prisma_client: PrismaClient, prisma_client: PrismaClient,
@ -439,8 +431,14 @@ class DBSpendUpdateWriter:
proxy_logging_obj: ProxyLogging, proxy_logging_obj: ProxyLogging,
): ):
""" """
Commits all the spend updates to the Database Commits all the spend `UPDATE` transactions to the Database
""" """
from litellm.proxy.utils import (
ProxyUpdateSpend,
_raise_failed_update_spend_exception,
)
### UPDATE USER TABLE ### ### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0: if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):

View file

@ -0,0 +1,110 @@
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
This is to prevent deadlocks and improve reliability
"""
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast
from litellm.caching import RedisCache, RedisClusterCache
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
else:
PrismaClient = Any
class DBSpendUpdateTransactions(TypedDict):
    """
    Bundle of the in-memory spend-update maps that get pushed to Redis.

    Every field maps an entity id (str) to a spend amount (float). The field
    names mirror the attribute names read off the Prisma client in
    `RedisUpdateBuffer.store_in_memory_spend_updates_in_redis`, and each field
    name is also used there as the prefix of the Redis key
    (`"<field name>:<entity id>"`), so the "transactons" spelling must be kept.
    """

    # One map per entity kind; all share the same {entity id: spend} shape.
    user_list_transactons: Dict[str, float]
    end_user_list_transactons: Dict[str, float]
    key_list_transactons: Dict[str, float]
    team_list_transactons: Dict[str, float]
    team_member_list_transactons: Dict[str, float]
    org_list_transactons: Dict[str, float]
class RedisUpdateBuffer:
    """
    Handles buffering database `UPDATE` transactions in Redis before committing them to the database

    This is to prevent deadlocks and improve reliability
    """

    def __init__(
        self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
    ):
        """
        Args:
            redis_cache: Redis (or Redis cluster) cache client used for the
                increments. When it is None the write methods of this class
                return without doing anything.
        """
        self.redis_cache = redis_cache

    @staticmethod
    def _should_commit_spend_updates_to_redis() -> bool:
        """
        Checks if the Pod should commit spend updates to Redis

        This setting enables buffering database transactions in Redis
        to improve reliability and reduce database contention

        Returns:
            The value of `general_settings["use_redis_transaction_buffer"]`
            (False when the key is missing). String values are converted with
            `str_to_bool`; a None result is treated as False.
        """
        # Read at call time so the current proxy settings are used.
        from litellm.proxy.proxy_server import general_settings

        _use_redis_transaction_buffer: Optional[Union[bool, str]] = (
            general_settings.get("use_redis_transaction_buffer", False)
        )
        if isinstance(_use_redis_transaction_buffer, str):
            _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
        if _use_redis_transaction_buffer is None:
            return False
        return _use_redis_transaction_buffer

    async def store_in_memory_spend_updates_in_redis(
        self,
        prisma_client: PrismaClient,
    ):
        """
        Stores the in-memory spend updates to Redis

        Each transaction is a dict stored as following:
        - key is the entity id
        - value is the spend amount

        ```
        {
            "0929880201": 10,
            "0929880202": 20,
            "0929880203": 30,
        }
        ```

        Args:
            prisma_client: object holding the six in-memory `*_list_transactons`
                dicts that are read here.

        Returns:
            None. Does nothing when no Redis client is configured.
        """
        if self.redis_cache is None:
            # Nothing to write to; skip collecting the transaction maps.
            return

        # NOTE(review): the in-memory dicts are only read here, never cleared.
        # Confirm the caller resets them (or skips the direct DB commit) so the
        # same spend is not applied twice.
        IN_MEMORY_UPDATE_TRANSACTIONS: DBSpendUpdateTransactions = (
            DBSpendUpdateTransactions(
                user_list_transactons=prisma_client.user_list_transactons,
                end_user_list_transactons=prisma_client.end_user_list_transactons,
                key_list_transactons=prisma_client.key_list_transactons,
                team_list_transactons=prisma_client.team_list_transactons,
                team_member_list_transactons=prisma_client.team_member_list_transactons,
                org_list_transactons=prisma_client.org_list_transactons,
            )
        )
        # The TypedDict field name doubles as the Redis key prefix.
        for key, _transactions in IN_MEMORY_UPDATE_TRANSACTIONS.items():
            await self.increment_all_transaction_objects_in_redis(
                key=key,
                transactions=cast(Dict, _transactions),
            )

    async def increment_all_transaction_objects_in_redis(
        self,
        key: str,
        transactions: Dict,
    ):
        """
        Increments all transaction objects in Redis

        Args:
            key: prefix for the Redis keys; each entry is written to
                `"<key>:<transaction id>"`.
            transactions: mapping of transaction/entity id -> amount to add.

        Returns:
            None. Does nothing when no Redis client is configured.
        """
        if self.redis_cache is None:
            return

        # Iterate over a snapshot: `transactions` is the live in-memory dict,
        # and every `await` below yields to the event loop. If another
        # coroutine adds an entry meanwhile, iterating the live view raises
        # "RuntimeError: dictionary changed size during iteration".
        for transaction_id, transaction_amount in list(transactions.items()):
            await self.redis_cache.async_increment(
                key=f"{key}:{transaction_id}",
                value=transaction_amount,
            )

    async def get_all_update_transactions_from_redis(self):
        """Placeholder: not implemented yet, currently returns None."""
        pass

View file

@ -265,6 +265,9 @@ class ProxyLogging:
) )
self.premium_user = premium_user self.premium_user = premium_user
self.service_logging_obj = ServiceLogging() self.service_logging_obj = ServiceLogging()
self.db_spend_update_writer = DBSpendUpdateWriter(
redis_cache=self.internal_usage_cache.dual_cache.redis_cache
)
def startup_event( def startup_event(
self, self,
@ -2675,7 +2678,7 @@ async def update_spend( # noqa: PLR0915
spend_logs: list, spend_logs: list,
""" """
n_retry_times = 3 n_retry_times = 3
await DBSpendUpdateWriter._commit_spend_updates_to_db( await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(
prisma_client=prisma_client, prisma_client=prisma_client,
n_retry_times=n_retry_times, n_retry_times=n_retry_times,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,