mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
use redis update buffer class
This commit is contained in:
parent
ad72078167
commit
963791bbb5
3 changed files with 147 additions and 36 deletions
|
@ -10,26 +10,24 @@ import os
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache, RedisCache, RedisClusterCache
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
DB_CONNECTION_ERROR_TYPES,
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
Litellm_EntityType,
|
Litellm_EntityType,
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
SpendLogsPayload,
|
SpendLogsPayload,
|
||||||
)
|
)
|
||||||
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
|
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
|
||||||
from litellm.proxy.utils import (
|
|
||||||
PrismaClient,
|
if TYPE_CHECKING:
|
||||||
ProxyLogging,
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
ProxyUpdateSpend,
|
else:
|
||||||
_raise_failed_update_spend_exception,
|
PrismaClient = Any
|
||||||
hash_token,
|
ProxyLogging = Any
|
||||||
)
|
|
||||||
from litellm.secret_managers.main import str_to_bool
|
|
||||||
|
|
||||||
|
|
||||||
class DBSpendUpdateWriter:
|
class DBSpendUpdateWriter:
|
||||||
|
@ -40,6 +38,12 @@ class DBSpendUpdateWriter:
|
||||||
2. Reading increments from redis or in memory list of transactions and committing them to db
|
2. Reading increments from redis or in memory list of transactions and committing them to db
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None,
):
    """
    Create a spend-update writer.

    Args:
        redis_cache: Optional Redis (or Redis cluster) cache. It is kept on
            the instance and also handed to the `RedisUpdateBuffer` that
            this writer owns.
    """
    self.redis_cache = redis_cache
    # The buffer receives the same cache object (which may be None).
    self.redis_update_buffer = RedisUpdateBuffer(redis_cache=redis_cache)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_database(
|
async def update_database(
|
||||||
# LiteLLM management object fields
|
# LiteLLM management object fields
|
||||||
|
@ -61,6 +65,7 @@ class DBSpendUpdateWriter:
|
||||||
prisma_client,
|
prisma_client,
|
||||||
user_api_key_cache,
|
user_api_key_cache,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.utils import ProxyUpdateSpend, hash_token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -315,6 +320,10 @@ class DBSpendUpdateWriter:
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
):
|
):
|
||||||
|
from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||||
|
get_logging_payload,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
payload = get_logging_payload(
|
payload = get_logging_payload(
|
||||||
|
@ -360,8 +369,8 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
return prisma_client
|
return prisma_client
|
||||||
|
|
||||||
@staticmethod
|
async def db_update_spend_transaction_handler(
|
||||||
async def db_spend_transaction_handler(
|
self,
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
n_retry_times: int,
|
n_retry_times: int,
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
@ -383,8 +392,10 @@ class DBSpendUpdateWriter:
|
||||||
else:
|
else:
|
||||||
- Regular flow of this method
|
- Regular flow of this method
|
||||||
"""
|
"""
|
||||||
if DBSpendUpdateWriter._should_commit_spend_updates_to_redis():
|
if RedisUpdateBuffer._should_commit_spend_updates_to_redis():
|
||||||
pass
|
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
)
|
||||||
|
|
||||||
if DBSpendUpdateWriter._should_commit_spend_updates_to_db():
|
if DBSpendUpdateWriter._should_commit_spend_updates_to_db():
|
||||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
||||||
|
@ -395,25 +406,6 @@ class DBSpendUpdateWriter:
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _should_commit_spend_updates_to_redis() -> bool:
|
|
||||||
"""
|
|
||||||
Checks if the Pod should commit spend updates to Redis
|
|
||||||
|
|
||||||
This setting enables buffering database transactions in Redis
|
|
||||||
to improve reliability and reduce database contention
|
|
||||||
"""
|
|
||||||
from litellm.proxy.proxy_server import general_settings
|
|
||||||
|
|
||||||
_use_redis_transaction_buffer: Optional[Union[bool, str]] = (
|
|
||||||
general_settings.get("use_redis_transaction_buffer", False)
|
|
||||||
)
|
|
||||||
if isinstance(_use_redis_transaction_buffer, str):
|
|
||||||
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
|
|
||||||
if _use_redis_transaction_buffer is None:
|
|
||||||
return False
|
|
||||||
return _use_redis_transaction_buffer
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _commit_spend_updates_to_redis(
|
async def _commit_spend_updates_to_redis(
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
|
@ -439,8 +431,14 @@ class DBSpendUpdateWriter:
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Commits all the spend updates to the Database
|
Commits all the spend `UPDATE` transactions to the Database
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from litellm.proxy.utils import (
|
||||||
|
ProxyUpdateSpend,
|
||||||
|
_raise_failed_update_spend_exception,
|
||||||
|
)
|
||||||
|
|
||||||
### UPDATE USER TABLE ###
|
### UPDATE USER TABLE ###
|
||||||
if len(prisma_client.user_list_transactons.keys()) > 0:
|
if len(prisma_client.user_list_transactons.keys()) > 0:
|
||||||
for i in range(n_retry_times + 1):
|
for i in range(n_retry_times + 1):
|
||||||
|
|
110
litellm/proxy/db/redis_update_buffer.py
Normal file
110
litellm/proxy/db/redis_update_buffer.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
"""
|
||||||
|
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
||||||
|
|
||||||
|
This is to prevent deadlocks and improve reliability
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict, Union, cast
|
||||||
|
|
||||||
|
from litellm.caching import RedisCache, RedisClusterCache
|
||||||
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
else:
|
||||||
|
PrismaClient = Any
|
||||||
|
|
||||||
|
|
||||||
|
class DBSpendUpdateTransactions(TypedDict):
    """
    Spend `UPDATE` transactions grouped by entity type.

    Every field is a dict mapping a string id to a float amount.
    """

    # NOTE: the `transactons` spelling is part of the interface - the field
    # names mirror the `prisma_client` attributes they are populated from.
    user_list_transactons: Dict[str, float]
    end_user_list_transactons: Dict[str, float]
    key_list_transactons: Dict[str, float]
    team_list_transactons: Dict[str, float]
    team_member_list_transactons: Dict[str, float]
    org_list_transactons: Dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
class RedisUpdateBuffer:
    """
    Buffers database `UPDATE` transactions in Redis before they are committed to the database.

    The goal is to prevent deadlocks and improve reliability.
    """

    def __init__(
        self,
        redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None,
    ):
        """
        Args:
            redis_cache: Redis (or Redis cluster) cache used as the buffer.
                When it is None, the increment method on this class does nothing.
        """
        self.redis_cache = redis_cache

    @staticmethod
    def _should_commit_spend_updates_to_redis() -> bool:
        """
        Checks if the Pod should commit spend updates to Redis.

        Reads `use_redis_transaction_buffer` from the proxy `general_settings`
        (defaults to False when the key is absent). A string value is converted
        with `str_to_bool`; a resulting `None` is treated as False.

        This setting enables buffering database transactions in Redis
        to improve reliability and reduce database contention.
        """
        # Imported at call time rather than at module import
        # (presumably to avoid a circular import - TODO confirm).
        from litellm.proxy.proxy_server import general_settings

        buffer_setting: Optional[Union[bool, str]] = general_settings.get(
            "use_redis_transaction_buffer", False
        )
        if isinstance(buffer_setting, str):
            buffer_setting = str_to_bool(buffer_setting)
        return False if buffer_setting is None else buffer_setting

    async def store_in_memory_spend_updates_in_redis(
        self,
        prisma_client: PrismaClient,
    ):
        """
        Stores the in-memory spend updates to Redis.

        The six `*_list_transactons` dicts are read from `prisma_client` up
        front, then each one is pushed to Redis in turn, using the field name
        as the Redis key prefix.

        Each transaction is a dict stored as following:
        - key is the entity id
        - value is the spend amount

        ```
        {
            "0929880201": 10,
            "0929880202": 20,
            "0929880203": 30,
        }
        ```
        """
        pending_updates: DBSpendUpdateTransactions = DBSpendUpdateTransactions(
            user_list_transactons=prisma_client.user_list_transactons,
            end_user_list_transactons=prisma_client.end_user_list_transactons,
            key_list_transactons=prisma_client.key_list_transactons,
            team_list_transactons=prisma_client.team_list_transactons,
            team_member_list_transactons=prisma_client.team_member_list_transactons,
            org_list_transactons=prisma_client.org_list_transactons,
        )
        for list_name, entity_spend in pending_updates.items():
            await self.increment_all_transaction_objects_in_redis(
                key=list_name,
                transactions=cast(Dict, entity_spend),
            )

    async def increment_all_transaction_objects_in_redis(
        self,
        key: str,
        transactions: Dict,
    ):
        """
        Increments all transaction objects in Redis.

        Args:
            key: Prefix for the Redis keys; each entry is written to
                `"{key}:{transaction_id}"`.
            transactions: Mapping of transaction id to the amount to increment by.

        Does nothing when no Redis cache is configured.
        """
        if self.redis_cache is not None:
            # One awaited increment per entry, in dict iteration order.
            for entry_id, amount in transactions.items():
                await self.redis_cache.async_increment(
                    key=f"{key}:{entry_id}",
                    value=amount,
                )

    async def get_all_update_transactions_from_redis(self):
        """
        Placeholder - not implemented yet; currently returns None.
        """
        return None
|
|
@ -265,6 +265,9 @@ class ProxyLogging:
|
||||||
)
|
)
|
||||||
self.premium_user = premium_user
|
self.premium_user = premium_user
|
||||||
self.service_logging_obj = ServiceLogging()
|
self.service_logging_obj = ServiceLogging()
|
||||||
|
self.db_spend_update_writer = DBSpendUpdateWriter(
|
||||||
|
redis_cache=self.internal_usage_cache.dual_cache.redis_cache
|
||||||
|
)
|
||||||
|
|
||||||
def startup_event(
|
def startup_event(
|
||||||
self,
|
self,
|
||||||
|
@ -2675,7 +2678,7 @@ async def update_spend( # noqa: PLR0915
|
||||||
spend_logs: list,
|
spend_logs: list,
|
||||||
"""
|
"""
|
||||||
n_retry_times = 3
|
n_retry_times = 3
|
||||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
n_retry_times=n_retry_times,
|
n_retry_times=n_retry_times,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue