mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
(Redis Cluster) - Fixes for using redis cluster + pipeline (#8442)
* update RedisCluster creation * update RedisClusterCache * add redis ClusterCache * update async_set_cache_pipeline * cleanup redis cluster usage * fix redis pipeline * test_init_async_client_returns_same_instance * fix redis cluster * update mypy_path * fix init_redis_cluster * remove stub * test redis commit * ClusterPipeline * fix import * RedisCluster import * fix redis cluster * Potential fix for code scanning alert no. 2129: Clear-text logging of sensitive information Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fix naming of redis cluster integration * test_redis_caching_ttl_pipeline * fix async_set_cache_pipeline --------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
parent
b465688afb
commit
b58348746a
7 changed files with 112 additions and 27 deletions
|
@ -14,7 +14,7 @@ import inspect
|
|||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
@ -26,15 +26,20 @@ from .base_cache import BaseCache
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.asyncio.client import Pipeline
|
||||
from redis.asyncio.cluster import ClusterPipeline
|
||||
|
||||
pipeline = Pipeline
|
||||
cluster_pipeline = ClusterPipeline
|
||||
async_redis_client = Redis
|
||||
async_redis_cluster_client = RedisCluster
|
||||
Span = _Span
|
||||
else:
|
||||
pipeline = Any
|
||||
cluster_pipeline = Any
|
||||
async_redis_client = Any
|
||||
async_redis_cluster_client = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
|
@ -122,7 +127,9 @@ class RedisCache(BaseCache):
|
|||
else:
|
||||
super().__init__() # defaults to 60s
|
||||
|
||||
def init_async_client(self):
|
||||
def init_async_client(
|
||||
self,
|
||||
) -> Union[async_redis_client, async_redis_cluster_client]:
|
||||
from .._redis import get_redis_async_client
|
||||
|
||||
return get_redis_async_client(
|
||||
|
@ -345,8 +352,14 @@ class RedisCache(BaseCache):
|
|||
)
|
||||
|
||||
async def _pipeline_helper(
|
||||
self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
|
||||
self,
|
||||
pipe: Union[pipeline, cluster_pipeline],
|
||||
cache_list: List[Tuple[Any, Any]],
|
||||
ttl: Optional[float],
|
||||
) -> List:
|
||||
"""
|
||||
Helper function for executing a pipeline of set operations on Redis
|
||||
"""
|
||||
ttl = self.get_ttl(ttl=ttl)
|
||||
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
||||
for cache_key, cache_value in cache_list:
|
||||
|
@ -359,7 +372,11 @@ class RedisCache(BaseCache):
|
|||
_td: Optional[timedelta] = None
|
||||
if ttl is not None:
|
||||
_td = timedelta(seconds=ttl)
|
||||
pipe.set(cache_key, json_cache_value, ex=_td)
|
||||
pipe.set( # type: ignore
|
||||
name=cache_key,
|
||||
value=json_cache_value,
|
||||
ex=_td,
|
||||
)
|
||||
# Execute the pipeline and return the results.
|
||||
results = await pipe.execute()
|
||||
return results
|
||||
|
@ -373,9 +390,8 @@ class RedisCache(BaseCache):
|
|||
# don't waste a network request if there's nothing to set
|
||||
if len(cache_list) == 0:
|
||||
return
|
||||
from redis.asyncio import Redis
|
||||
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
|
||||
print_verbose(
|
||||
|
@ -384,7 +400,7 @@ class RedisCache(BaseCache):
|
|||
cache_value: Any = None
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
async with redis_client.pipeline(transaction=False) as pipe:
|
||||
results = await self._pipeline_helper(pipe, cache_list, ttl)
|
||||
|
||||
print_verbose(f"pipeline results: {results}")
|
||||
|
@ -730,7 +746,8 @@ class RedisCache(BaseCache):
|
|||
"""
|
||||
Use Redis for bulk read operations
|
||||
"""
|
||||
_redis_client = await self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
key_value_dict = {}
|
||||
start_time = time.time()
|
||||
try:
|
||||
|
@ -822,7 +839,8 @@ class RedisCache(BaseCache):
|
|||
raise e
|
||||
|
||||
async def ping(self) -> bool:
|
||||
_redis_client = self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
start_time = time.time()
|
||||
async with _redis_client as redis_client:
|
||||
print_verbose("Pinging Async Redis Cache")
|
||||
|
@ -858,7 +876,8 @@ class RedisCache(BaseCache):
|
|||
raise e
|
||||
|
||||
async def delete_cache_keys(self, keys):
|
||||
_redis_client = self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
# keys is a list, unpack it so it gets passed as individual elements to delete
|
||||
async with _redis_client as redis_client:
|
||||
await redis_client.delete(*keys)
|
||||
|
@ -881,7 +900,8 @@ class RedisCache(BaseCache):
|
|||
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
|
||||
|
||||
async def async_delete_cache(self, key: str):
|
||||
_redis_client = self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
# keys is str
|
||||
async with _redis_client as redis_client:
|
||||
await redis_client.delete(key)
|
||||
|
@ -936,7 +956,7 @@ class RedisCache(BaseCache):
|
|||
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
async with redis_client.pipeline(transaction=False) as pipe:
|
||||
results = await self._pipeline_increment_helper(
|
||||
pipe, increment_list
|
||||
)
|
||||
|
@ -991,7 +1011,8 @@ class RedisCache(BaseCache):
|
|||
Redis ref: https://redis.io/docs/latest/commands/ttl/
|
||||
"""
|
||||
try:
|
||||
_redis_client = await self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
async with _redis_client as redis_client:
|
||||
ttl = await redis_client.ttl(key)
|
||||
if ttl <= -1: # -1 means the key does not exist, -2 key does not exist
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue