mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
use redis update buffer class
This commit is contained in:
parent
ad72078167
commit
963791bbb5
3 changed files with 147 additions and 36 deletions
|
@ -10,26 +10,24 @@ import os
|
|||
import time
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.caching import DualCache, RedisCache, RedisClusterCache
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
Litellm_EntityType,
|
||||
LiteLLM_UserTable,
|
||||
SpendLogsPayload,
|
||||
)
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
ProxyLogging,
|
||||
ProxyUpdateSpend,
|
||||
_raise_failed_update_spend_exception,
|
||||
hash_token,
|
||||
)
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
else:
|
||||
PrismaClient = Any
|
||||
ProxyLogging = Any
|
||||
|
||||
|
||||
class DBSpendUpdateWriter:
|
||||
|
@ -40,6 +38,12 @@ class DBSpendUpdateWriter:
|
|||
2. Reading increments from redis or in memory list of transactions and committing them to db
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
|
||||
):
|
||||
self.redis_cache = redis_cache
|
||||
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=redis_cache)
|
||||
|
||||
@staticmethod
|
||||
async def update_database(
|
||||
# LiteLLM management object fields
|
||||
|
@ -61,6 +65,7 @@ class DBSpendUpdateWriter:
|
|||
prisma_client,
|
||||
user_api_key_cache,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyUpdateSpend, hash_token
|
||||
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -315,6 +320,10 @@ class DBSpendUpdateWriter:
|
|||
response_cost: Optional[float],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
):
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||
get_logging_payload,
|
||||
)
|
||||
|
||||
try:
|
||||
if prisma_client:
|
||||
payload = get_logging_payload(
|
||||
|
@ -360,8 +369,8 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
return prisma_client
|
||||
|
||||
@staticmethod
|
||||
async def db_spend_transaction_handler(
|
||||
async def db_update_spend_transaction_handler(
|
||||
self,
|
||||
prisma_client: PrismaClient,
|
||||
n_retry_times: int,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
|
@ -383,8 +392,10 @@ class DBSpendUpdateWriter:
|
|||
else:
|
||||
- Regular flow of this method
|
||||
"""
|
||||
if DBSpendUpdateWriter._should_commit_spend_updates_to_redis():
|
||||
pass
|
||||
if RedisUpdateBuffer._should_commit_spend_updates_to_redis():
|
||||
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
if DBSpendUpdateWriter._should_commit_spend_updates_to_db():
|
||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
||||
|
@ -395,25 +406,6 @@ class DBSpendUpdateWriter:
|
|||
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _should_commit_spend_updates_to_redis() -> bool:
|
||||
"""
|
||||
Checks if the Pod should commit spend updates to Redis
|
||||
|
||||
This setting enables buffering database transactions in Redis
|
||||
to improve reliability and reduce database contention
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
_use_redis_transaction_buffer: Optional[Union[bool, str]] = (
|
||||
general_settings.get("use_redis_transaction_buffer", False)
|
||||
)
|
||||
if isinstance(_use_redis_transaction_buffer, str):
|
||||
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
|
||||
if _use_redis_transaction_buffer is None:
|
||||
return False
|
||||
return _use_redis_transaction_buffer
|
||||
|
||||
@staticmethod
|
||||
async def _commit_spend_updates_to_redis(
|
||||
prisma_client: PrismaClient,
|
||||
|
@ -439,8 +431,14 @@ class DBSpendUpdateWriter:
|
|||
proxy_logging_obj: ProxyLogging,
|
||||
):
|
||||
"""
|
||||
Commits all the spend updates to the Database
|
||||
Commits all the spend `UPDATE` transactions to the Database
|
||||
|
||||
"""
|
||||
from litellm.proxy.utils import (
|
||||
ProxyUpdateSpend,
|
||||
_raise_failed_update_spend_exception,
|
||||
)
|
||||
|
||||
### UPDATE USER TABLE ###
|
||||
if len(prisma_client.user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
|
|
110
litellm/proxy/db/redis_update_buffer.py
Normal file
110
litellm/proxy/db/redis_update_buffer.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
"""
|
||||
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
||||
|
||||
This is to prevent deadlocks and improve reliability
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast
|
||||
|
||||
from litellm.caching import RedisCache, RedisClusterCache
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
else:
|
||||
PrismaClient = Any
|
||||
|
||||
|
||||
class DBSpendUpdateTransactions(TypedDict):
|
||||
user_list_transactons: Dict[str, float]
|
||||
end_user_list_transactons: Dict[str, float]
|
||||
key_list_transactons: Dict[str, float]
|
||||
team_list_transactons: Dict[str, float]
|
||||
team_member_list_transactons: Dict[str, float]
|
||||
org_list_transactons: Dict[str, float]
|
||||
|
||||
|
||||
class RedisUpdateBuffer:
|
||||
"""
|
||||
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
||||
|
||||
This is to prevent deadlocks and improve reliability
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
|
||||
):
|
||||
self.redis_cache = redis_cache
|
||||
|
||||
@staticmethod
|
||||
def _should_commit_spend_updates_to_redis() -> bool:
|
||||
"""
|
||||
Checks if the Pod should commit spend updates to Redis
|
||||
|
||||
This setting enables buffering database transactions in Redis
|
||||
to improve reliability and reduce database contention
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
_use_redis_transaction_buffer: Optional[Union[bool, str]] = (
|
||||
general_settings.get("use_redis_transaction_buffer", False)
|
||||
)
|
||||
if isinstance(_use_redis_transaction_buffer, str):
|
||||
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
|
||||
if _use_redis_transaction_buffer is None:
|
||||
return False
|
||||
return _use_redis_transaction_buffer
|
||||
|
||||
async def store_in_memory_spend_updates_in_redis(
|
||||
self,
|
||||
prisma_client: PrismaClient,
|
||||
):
|
||||
"""
|
||||
Stores the in-memory spend updates to Redis
|
||||
|
||||
Each transaction is a dict stored as following:
|
||||
- key is the entity id
|
||||
- value is the spend amount
|
||||
|
||||
```
|
||||
{
|
||||
"0929880201": 10,
|
||||
"0929880202": 20,
|
||||
"0929880203": 30,
|
||||
}
|
||||
```
|
||||
"""
|
||||
IN_MEMORY_UPDATE_TRANSACTIONS: DBSpendUpdateTransactions = (
|
||||
DBSpendUpdateTransactions(
|
||||
user_list_transactons=prisma_client.user_list_transactons,
|
||||
end_user_list_transactons=prisma_client.end_user_list_transactons,
|
||||
key_list_transactons=prisma_client.key_list_transactons,
|
||||
team_list_transactons=prisma_client.team_list_transactons,
|
||||
team_member_list_transactons=prisma_client.team_member_list_transactons,
|
||||
org_list_transactons=prisma_client.org_list_transactons,
|
||||
)
|
||||
)
|
||||
for key, _transactions in IN_MEMORY_UPDATE_TRANSACTIONS.items():
|
||||
await self.increment_all_transaction_objects_in_redis(
|
||||
key=key,
|
||||
transactions=cast(Dict, _transactions),
|
||||
)
|
||||
|
||||
async def increment_all_transaction_objects_in_redis(
|
||||
self,
|
||||
key: str,
|
||||
transactions: Dict,
|
||||
):
|
||||
"""
|
||||
Increments all transaction objects in Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return
|
||||
for transaction_id, transaction_amount in transactions.items():
|
||||
await self.redis_cache.async_increment(
|
||||
key=f"{key}:{transaction_id}",
|
||||
value=transaction_amount,
|
||||
)
|
||||
|
||||
async def get_all_update_transactions_from_redis(self):
|
||||
pass
|
|
@ -265,6 +265,9 @@ class ProxyLogging:
|
|||
)
|
||||
self.premium_user = premium_user
|
||||
self.service_logging_obj = ServiceLogging()
|
||||
self.db_spend_update_writer = DBSpendUpdateWriter(
|
||||
redis_cache=self.internal_usage_cache.dual_cache.redis_cache
|
||||
)
|
||||
|
||||
def startup_event(
|
||||
self,
|
||||
|
@ -2675,7 +2678,7 @@ async def update_spend( # noqa: PLR0915
|
|||
spend_logs: list,
|
||||
"""
|
||||
n_retry_times = 3
|
||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
||||
await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(
|
||||
prisma_client=prisma_client,
|
||||
n_retry_times=n_retry_times,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue