use redis update buffer class

This commit is contained in:
Ishaan Jaff 2025-03-27 19:12:51 -07:00
parent ad72078167
commit 963791bbb5
3 changed files with 147 additions and 36 deletions

View file

@ -10,26 +10,24 @@ import os
import time
import traceback
from datetime import datetime, timedelta
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.caching import DualCache, RedisCache, RedisClusterCache
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
Litellm_EntityType,
LiteLLM_UserTable,
SpendLogsPayload,
)
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
from litellm.proxy.utils import (
PrismaClient,
ProxyLogging,
ProxyUpdateSpend,
_raise_failed_update_spend_exception,
hash_token,
)
from litellm.secret_managers.main import str_to_bool
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient, ProxyLogging
else:
PrismaClient = Any
ProxyLogging = Any
class DBSpendUpdateWriter:
@ -40,6 +38,12 @@ class DBSpendUpdateWriter:
2. Reading increments from redis or in memory list of transactions and committing them to db
"""
def __init__(
self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
):
self.redis_cache = redis_cache
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=redis_cache)
@staticmethod
async def update_database(
# LiteLLM management object fields
@ -61,6 +65,7 @@ class DBSpendUpdateWriter:
prisma_client,
user_api_key_cache,
)
from litellm.proxy.utils import ProxyUpdateSpend, hash_token
try:
verbose_proxy_logger.debug(
@ -315,6 +320,10 @@ class DBSpendUpdateWriter:
response_cost: Optional[float],
prisma_client: Optional[PrismaClient],
):
from litellm.proxy.spend_tracking.spend_tracking_utils import (
get_logging_payload,
)
try:
if prisma_client:
payload = get_logging_payload(
@ -360,8 +369,8 @@ class DBSpendUpdateWriter:
)
return prisma_client
@staticmethod
async def db_spend_transaction_handler(
async def db_update_spend_transaction_handler(
self,
prisma_client: PrismaClient,
n_retry_times: int,
proxy_logging_obj: ProxyLogging,
@ -383,8 +392,10 @@ class DBSpendUpdateWriter:
else:
- Regular flow of this method
"""
if DBSpendUpdateWriter._should_commit_spend_updates_to_redis():
pass
if RedisUpdateBuffer._should_commit_spend_updates_to_redis():
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
prisma_client=prisma_client,
)
if DBSpendUpdateWriter._should_commit_spend_updates_to_db():
await DBSpendUpdateWriter._commit_spend_updates_to_db(
@ -395,25 +406,6 @@ class DBSpendUpdateWriter:
pass
@staticmethod
def _should_commit_spend_updates_to_redis() -> bool:
"""
Checks if the Pod should commit spend updates to Redis
This setting enables buffering database transactions in Redis
to improve reliability and reduce database contention
"""
from litellm.proxy.proxy_server import general_settings
_use_redis_transaction_buffer: Optional[Union[bool, str]] = (
general_settings.get("use_redis_transaction_buffer", False)
)
if isinstance(_use_redis_transaction_buffer, str):
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
if _use_redis_transaction_buffer is None:
return False
return _use_redis_transaction_buffer
@staticmethod
async def _commit_spend_updates_to_redis(
prisma_client: PrismaClient,
@ -439,8 +431,14 @@ class DBSpendUpdateWriter:
proxy_logging_obj: ProxyLogging,
):
"""
Commits all the spend updates to the Database
Commits all the spend `UPDATE` transactions to the Database
"""
from litellm.proxy.utils import (
ProxyUpdateSpend,
_raise_failed_update_spend_exception,
)
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):

View file

@ -0,0 +1,110 @@
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
This is to prevent deadlocks and improve reliability
"""
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast
from litellm.caching import RedisCache, RedisClusterCache
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
else:
PrismaClient = Any
class DBSpendUpdateTransactions(TypedDict):
user_list_transactons: Dict[str, float]
end_user_list_transactons: Dict[str, float]
key_list_transactons: Dict[str, float]
team_list_transactons: Dict[str, float]
team_member_list_transactons: Dict[str, float]
org_list_transactons: Dict[str, float]
class RedisUpdateBuffer:
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
This is to prevent deadlocks and improve reliability
"""
def __init__(
self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
):
self.redis_cache = redis_cache
@staticmethod
def _should_commit_spend_updates_to_redis() -> bool:
"""
Checks if the Pod should commit spend updates to Redis
This setting enables buffering database transactions in Redis
to improve reliability and reduce database contention
"""
from litellm.proxy.proxy_server import general_settings
_use_redis_transaction_buffer: Optional[Union[bool, str]] = (
general_settings.get("use_redis_transaction_buffer", False)
)
if isinstance(_use_redis_transaction_buffer, str):
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
if _use_redis_transaction_buffer is None:
return False
return _use_redis_transaction_buffer
async def store_in_memory_spend_updates_in_redis(
self,
prisma_client: PrismaClient,
):
"""
Stores the in-memory spend updates to Redis
Each transaction is a dict stored as following:
- key is the entity id
- value is the spend amount
```
{
"0929880201": 10,
"0929880202": 20,
"0929880203": 30,
}
```
"""
IN_MEMORY_UPDATE_TRANSACTIONS: DBSpendUpdateTransactions = (
DBSpendUpdateTransactions(
user_list_transactons=prisma_client.user_list_transactons,
end_user_list_transactons=prisma_client.end_user_list_transactons,
key_list_transactons=prisma_client.key_list_transactons,
team_list_transactons=prisma_client.team_list_transactons,
team_member_list_transactons=prisma_client.team_member_list_transactons,
org_list_transactons=prisma_client.org_list_transactons,
)
)
for key, _transactions in IN_MEMORY_UPDATE_TRANSACTIONS.items():
await self.increment_all_transaction_objects_in_redis(
key=key,
transactions=cast(Dict, _transactions),
)
async def increment_all_transaction_objects_in_redis(
self,
key: str,
transactions: Dict,
):
"""
Increments all transaction objects in Redis
"""
if self.redis_cache is None:
return
for transaction_id, transaction_amount in transactions.items():
await self.redis_cache.async_increment(
key=f"{key}:{transaction_id}",
value=transaction_amount,
)
async def get_all_update_transactions_from_redis(self):
pass

View file

@ -265,6 +265,9 @@ class ProxyLogging:
)
self.premium_user = premium_user
self.service_logging_obj = ServiceLogging()
self.db_spend_update_writer = DBSpendUpdateWriter(
redis_cache=self.internal_usage_cache.dual_cache.redis_cache
)
def startup_event(
self,
@ -2675,7 +2678,7 @@ async def update_spend( # noqa: PLR0915
spend_logs: list,
"""
n_retry_times = 3
await DBSpendUpdateWriter._commit_spend_updates_to_db(
await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(
prisma_client=prisma_client,
n_retry_times=n_retry_times,
proxy_logging_obj=proxy_logging_obj,