mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
135 lines
4.7 KiB
Python
135 lines
4.7 KiB
Python
"""
|
|
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
|
|
|
This is to prevent deadlocks and improve reliability
|
|
"""
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast
|
|
|
|
from litellm.caching import RedisCache, RedisClusterCache
|
|
from litellm.secret_managers.main import str_to_bool
|
|
|
|
if TYPE_CHECKING:
|
|
from litellm.proxy.utils import PrismaClient
|
|
else:
|
|
PrismaClient = Any
|
|
|
|
|
|
class DBSpendUpdateTransactions(TypedDict):
    """
    Pending spend updates, grouped into one mapping per transaction list.

    Every field is a `Dict[str, float]`. The field names (including the
    `transactons` spelling) are used verbatim as attribute names on the Prisma
    client and as cache keys elsewhere in this module, so they must not be
    renamed here in isolation.
    """

    # Per-user and per-end-user lists.
    user_list_transactons: Dict[str, float]
    end_user_list_transactons: Dict[str, float]

    # Per-key list.
    key_list_transactons: Dict[str, float]

    # Team, team-member and org lists.
    team_list_transactons: Dict[str, float]
    team_member_list_transactons: Dict[str, float]
    org_list_transactons: Dict[str, float]
|
|
|
|
|
|
class RedisUpdateBuffer:
    """
    Handles buffering database `UPDATE` transactions in Redis before committing them to the database

    This is to prevent deadlocks and improve reliability
    """

    def __init__(
        self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
    ):
        """
        Args:
            redis_cache: Cache client used for all Redis reads/writes. When it
                is None, `increment_all_transaction_objects_in_redis` does
                nothing and `get_all_update_transactions_from_redis` returns
                None.
        """
        self.redis_cache = redis_cache

    @staticmethod
    def _should_commit_spend_updates_to_redis() -> bool:
        """
        Checks if the Pod should commit spend updates to Redis

        This setting enables buffering database transactions in Redis
        to improve reliability and reduce database contention

        Reads `use_redis_transaction_buffer` from the proxy's
        `general_settings`. A string value is converted with `str_to_bool`;
        a missing setting, or a value that resolves to None, yields False.
        """
        # Imported at call time rather than at module import time.
        from litellm.proxy.proxy_server import general_settings

        _use_redis_transaction_buffer: Optional[Union[bool, str]] = (
            general_settings.get("use_redis_transaction_buffer", False)
        )
        if isinstance(_use_redis_transaction_buffer, str):
            _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
        if _use_redis_transaction_buffer is None:
            return False
        return _use_redis_transaction_buffer

    async def store_in_memory_spend_updates_in_redis(
        self,
        prisma_client: PrismaClient,
    ):
        """
        Stores the in-memory spend updates to Redis

        Each transaction is a dict stored as following:
        - key is the entity id
        - value is the spend amount

        ```
        {
            "0929880201": 10,
            "0929880202": 20,
            "0929880203": 30,
        }
        ```

        Args:
            prisma_client: Object exposing the six `*_list_transactons`
                mappings as attributes. The mappings are only read here; they
                are not cleared or otherwise modified by this method.
        """
        IN_MEMORY_UPDATE_TRANSACTIONS: DBSpendUpdateTransactions = (
            DBSpendUpdateTransactions(
                user_list_transactons=prisma_client.user_list_transactons,
                end_user_list_transactons=prisma_client.end_user_list_transactons,
                key_list_transactons=prisma_client.key_list_transactons,
                team_list_transactons=prisma_client.team_list_transactons,
                team_member_list_transactons=prisma_client.team_member_list_transactons,
                org_list_transactons=prisma_client.org_list_transactons,
            )
        )
        # The TypedDict field name doubles as the Redis key prefix for its list.
        for key, _transactions in IN_MEMORY_UPDATE_TRANSACTIONS.items():
            await self.increment_all_transaction_objects_in_redis(
                key=key,
                transactions=cast(Dict, _transactions),
            )

    async def increment_all_transaction_objects_in_redis(
        self,
        key: str,
        transactions: Dict,
    ):
        """
        Increments all transaction objects in Redis

        Args:
            key: Prefix for the Redis keys written by this call.
            transactions: Mapping of transaction id -> amount. Each entry is
                applied as one `async_increment` on `"<key>:<transaction id>"`,
                awaited one after another.

        Does nothing when no Redis cache is configured.
        """
        if self.redis_cache is None:
            return
        for transaction_id, transaction_amount in transactions.items():
            await self.redis_cache.async_increment(
                key=f"{key}:{transaction_id}",
                value=transaction_amount,
            )

    async def get_all_update_transactions_from_redis(
        self,
    ) -> Optional[DBSpendUpdateTransactions]:
        """
        Gets all the update transactions from Redis

        Returns:
            None when no Redis cache is configured or the batch lookup itself
            returns None. Otherwise a `DBSpendUpdateTransactions` in which any
            list that is missing from the lookup result (or present with an
            empty / None value) is an empty dict.
        """
        if self.redis_cache is None:
            return None
        # NOTE(review): increment_all_transaction_objects_in_redis writes to
        # "<key>:<transaction_id>", while this lookup fetches the bare "<key>"
        # names - confirm the two key layouts are meant to line up.
        expected_keys = [
            "user_list_transactons",
            "end_user_list_transactons",
            "key_list_transactons",
            "team_list_transactons",
            "team_member_list_transactons",
            "org_list_transactons",
        ]
        result = await self.redis_cache.async_batch_get_cache(expected_keys)
        if result is None:
            return None

        def _or_empty(name: str) -> Dict[str, float]:
            # A batch get may report a miss as an explicit None value instead
            # of omitting the key; `result.get(name, {})` would pass that None
            # through, so normalise any falsy value to an empty dict.
            return result.get(name) or {}

        return DBSpendUpdateTransactions(
            user_list_transactons=_or_empty("user_list_transactons"),
            end_user_list_transactons=_or_empty("end_user_list_transactons"),
            key_list_transactons=_or_empty("key_list_transactons"),
            team_list_transactons=_or_empty("team_list_transactons"),
            team_member_list_transactons=_or_empty("team_member_list_transactons"),
            org_list_transactons=_or_empty("org_list_transactons"),
        )
|