mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
(Redis Cluster) - Fixes for using redis cluster + pipeline (#8442)
* update RedisCluster creation * update RedisClusterCache * add redis ClusterCache * update async_set_cache_pipeline * cleanup redis cluster usage * fix redis pipeline * test_init_async_client_returns_same_instance * fix redis cluster * update mypy_path * fix init_redis_cluster * remove stub * test redis commit * ClusterPipeline * fix import * RedisCluster import * fix redis cluster * Potential fix for code scanning alert no. 2129: Clear-text logging of sensitive information Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fix naming of redis cluster integration * test_redis_caching_ttl_pipeline * fix async_set_cache_pipeline --------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
parent
b710407763
commit
40e3af0428
7 changed files with 112 additions and 27 deletions
|
@ -183,7 +183,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
|
"init_redis_cluster: startup nodes are being initialized."
|
||||||
)
|
)
|
||||||
from redis.cluster import ClusterNode
|
from redis.cluster import ClusterNode
|
||||||
|
|
||||||
|
@ -266,7 +266,9 @@ def get_redis_client(**env_overrides):
|
||||||
return redis.Redis(**redis_kwargs)
|
return redis.Redis(**redis_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_redis_async_client(**env_overrides) -> async_redis.Redis:
|
def get_redis_async_client(
|
||||||
|
**env_overrides,
|
||||||
|
) -> async_redis.Redis:
|
||||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||||
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
|
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
|
||||||
|
|
|
@ -4,5 +4,6 @@ from .dual_cache import DualCache
|
||||||
from .in_memory_cache import InMemoryCache
|
from .in_memory_cache import InMemoryCache
|
||||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||||
from .redis_cache import RedisCache
|
from .redis_cache import RedisCache
|
||||||
|
from .redis_cluster_cache import RedisClusterCache
|
||||||
from .redis_semantic_cache import RedisSemanticCache
|
from .redis_semantic_cache import RedisSemanticCache
|
||||||
from .s3_cache import S3Cache
|
from .s3_cache import S3Cache
|
||||||
|
|
|
@ -41,6 +41,7 @@ from .dual_cache import DualCache # noqa
|
||||||
from .in_memory_cache import InMemoryCache
|
from .in_memory_cache import InMemoryCache
|
||||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||||
from .redis_cache import RedisCache
|
from .redis_cache import RedisCache
|
||||||
|
from .redis_cluster_cache import RedisClusterCache
|
||||||
from .redis_semantic_cache import RedisSemanticCache
|
from .redis_semantic_cache import RedisSemanticCache
|
||||||
from .s3_cache import S3Cache
|
from .s3_cache import S3Cache
|
||||||
|
|
||||||
|
@ -158,14 +159,23 @@ class Cache:
|
||||||
None. Cache is set as a litellm param
|
None. Cache is set as a litellm param
|
||||||
"""
|
"""
|
||||||
if type == LiteLLMCacheType.REDIS:
|
if type == LiteLLMCacheType.REDIS:
|
||||||
self.cache: BaseCache = RedisCache(
|
if redis_startup_nodes:
|
||||||
host=host,
|
self.cache: BaseCache = RedisClusterCache(
|
||||||
port=port,
|
host=host,
|
||||||
password=password,
|
port=port,
|
||||||
redis_flush_size=redis_flush_size,
|
password=password,
|
||||||
startup_nodes=redis_startup_nodes,
|
redis_flush_size=redis_flush_size,
|
||||||
**kwargs,
|
startup_nodes=redis_startup_nodes,
|
||||||
)
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.cache = RedisCache(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
password=password,
|
||||||
|
redis_flush_size=redis_flush_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
||||||
self.cache = RedisSemanticCache(
|
self.cache = RedisSemanticCache(
|
||||||
host=host,
|
host=host,
|
||||||
|
|
|
@ -14,7 +14,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import print_verbose, verbose_logger
|
from litellm._logging import print_verbose, verbose_logger
|
||||||
|
@ -26,15 +26,20 @@ from .base_cache import BaseCache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis, RedisCluster
|
||||||
from redis.asyncio.client import Pipeline
|
from redis.asyncio.client import Pipeline
|
||||||
|
from redis.asyncio.cluster import ClusterPipeline
|
||||||
|
|
||||||
pipeline = Pipeline
|
pipeline = Pipeline
|
||||||
|
cluster_pipeline = ClusterPipeline
|
||||||
async_redis_client = Redis
|
async_redis_client = Redis
|
||||||
|
async_redis_cluster_client = RedisCluster
|
||||||
Span = _Span
|
Span = _Span
|
||||||
else:
|
else:
|
||||||
pipeline = Any
|
pipeline = Any
|
||||||
|
cluster_pipeline = Any
|
||||||
async_redis_client = Any
|
async_redis_client = Any
|
||||||
|
async_redis_cluster_client = Any
|
||||||
Span = Any
|
Span = Any
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,7 +127,9 @@ class RedisCache(BaseCache):
|
||||||
else:
|
else:
|
||||||
super().__init__() # defaults to 60s
|
super().__init__() # defaults to 60s
|
||||||
|
|
||||||
def init_async_client(self):
|
def init_async_client(
|
||||||
|
self,
|
||||||
|
) -> Union[async_redis_client, async_redis_cluster_client]:
|
||||||
from .._redis import get_redis_async_client
|
from .._redis import get_redis_async_client
|
||||||
|
|
||||||
return get_redis_async_client(
|
return get_redis_async_client(
|
||||||
|
@ -345,8 +352,14 @@ class RedisCache(BaseCache):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _pipeline_helper(
|
async def _pipeline_helper(
|
||||||
self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
|
self,
|
||||||
|
pipe: Union[pipeline, cluster_pipeline],
|
||||||
|
cache_list: List[Tuple[Any, Any]],
|
||||||
|
ttl: Optional[float],
|
||||||
) -> List:
|
) -> List:
|
||||||
|
"""
|
||||||
|
Helper function for executing a pipeline of set operations on Redis
|
||||||
|
"""
|
||||||
ttl = self.get_ttl(ttl=ttl)
|
ttl = self.get_ttl(ttl=ttl)
|
||||||
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
||||||
for cache_key, cache_value in cache_list:
|
for cache_key, cache_value in cache_list:
|
||||||
|
@ -359,7 +372,11 @@ class RedisCache(BaseCache):
|
||||||
_td: Optional[timedelta] = None
|
_td: Optional[timedelta] = None
|
||||||
if ttl is not None:
|
if ttl is not None:
|
||||||
_td = timedelta(seconds=ttl)
|
_td = timedelta(seconds=ttl)
|
||||||
pipe.set(cache_key, json_cache_value, ex=_td)
|
pipe.set( # type: ignore
|
||||||
|
name=cache_key,
|
||||||
|
value=json_cache_value,
|
||||||
|
ex=_td,
|
||||||
|
)
|
||||||
# Execute the pipeline and return the results.
|
# Execute the pipeline and return the results.
|
||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
return results
|
return results
|
||||||
|
@ -373,9 +390,8 @@ class RedisCache(BaseCache):
|
||||||
# don't waste a network request if there's nothing to set
|
# don't waste a network request if there's nothing to set
|
||||||
if len(cache_list) == 0:
|
if len(cache_list) == 0:
|
||||||
return
|
return
|
||||||
from redis.asyncio import Redis
|
|
||||||
|
|
||||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
_redis_client = self.init_async_client()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
|
@ -384,7 +400,7 @@ class RedisCache(BaseCache):
|
||||||
cache_value: Any = None
|
cache_value: Any = None
|
||||||
try:
|
try:
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
async with redis_client.pipeline(transaction=True) as pipe:
|
async with redis_client.pipeline(transaction=False) as pipe:
|
||||||
results = await self._pipeline_helper(pipe, cache_list, ttl)
|
results = await self._pipeline_helper(pipe, cache_list, ttl)
|
||||||
|
|
||||||
print_verbose(f"pipeline results: {results}")
|
print_verbose(f"pipeline results: {results}")
|
||||||
|
@ -730,7 +746,8 @@ class RedisCache(BaseCache):
|
||||||
"""
|
"""
|
||||||
Use Redis for bulk read operations
|
Use Redis for bulk read operations
|
||||||
"""
|
"""
|
||||||
_redis_client = await self.init_async_client()
|
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
|
||||||
|
_redis_client: Any = self.init_async_client()
|
||||||
key_value_dict = {}
|
key_value_dict = {}
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
|
@ -822,7 +839,8 @@ class RedisCache(BaseCache):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def ping(self) -> bool:
|
async def ping(self) -> bool:
|
||||||
_redis_client = self.init_async_client()
|
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
|
||||||
|
_redis_client: Any = self.init_async_client()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
print_verbose("Pinging Async Redis Cache")
|
print_verbose("Pinging Async Redis Cache")
|
||||||
|
@ -858,7 +876,8 @@ class RedisCache(BaseCache):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def delete_cache_keys(self, keys):
|
async def delete_cache_keys(self, keys):
|
||||||
_redis_client = self.init_async_client()
|
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
||||||
|
_redis_client: Any = self.init_async_client()
|
||||||
# keys is a list, unpack it so it gets passed as individual elements to delete
|
# keys is a list, unpack it so it gets passed as individual elements to delete
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
await redis_client.delete(*keys)
|
await redis_client.delete(*keys)
|
||||||
|
@ -881,7 +900,8 @@ class RedisCache(BaseCache):
|
||||||
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
|
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
|
||||||
|
|
||||||
async def async_delete_cache(self, key: str):
|
async def async_delete_cache(self, key: str):
|
||||||
_redis_client = self.init_async_client()
|
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
||||||
|
_redis_client: Any = self.init_async_client()
|
||||||
# keys is str
|
# keys is str
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
await redis_client.delete(key)
|
await redis_client.delete(key)
|
||||||
|
@ -936,7 +956,7 @@ class RedisCache(BaseCache):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
async with redis_client.pipeline(transaction=True) as pipe:
|
async with redis_client.pipeline(transaction=False) as pipe:
|
||||||
results = await self._pipeline_increment_helper(
|
results = await self._pipeline_increment_helper(
|
||||||
pipe, increment_list
|
pipe, increment_list
|
||||||
)
|
)
|
||||||
|
@ -991,7 +1011,8 @@ class RedisCache(BaseCache):
|
||||||
Redis ref: https://redis.io/docs/latest/commands/ttl/
|
Redis ref: https://redis.io/docs/latest/commands/ttl/
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
_redis_client = await self.init_async_client()
|
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
|
||||||
|
_redis_client: Any = self.init_async_client()
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
ttl = await redis_client.ttl(key)
|
ttl = await redis_client.ttl(key)
|
||||||
if ttl <= -1: # -1 means the key does not exist, -2 key does not exist
|
if ttl <= -1: # -1 means the key does not exist, -2 key does not exist
|
||||||
|
|
44
litellm/caching/redis_cluster_cache.py
Normal file
44
litellm/caching/redis_cluster_cache.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
"""
|
||||||
|
Redis Cluster Cache implementation
|
||||||
|
|
||||||
|
Key differences:
|
||||||
|
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from litellm.caching.redis_cache import RedisCache
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from opentelemetry.trace import Span as _Span
|
||||||
|
from redis.asyncio import Redis, RedisCluster
|
||||||
|
from redis.asyncio.client import Pipeline
|
||||||
|
|
||||||
|
pipeline = Pipeline
|
||||||
|
async_redis_client = Redis
|
||||||
|
Span = _Span
|
||||||
|
else:
|
||||||
|
pipeline = Any
|
||||||
|
async_redis_client = Any
|
||||||
|
Span = Any
|
||||||
|
|
||||||
|
|
||||||
|
class RedisClusterCache(RedisCache):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.redis_cluster_client: Optional[RedisCluster] = None
|
||||||
|
|
||||||
|
def init_async_client(self):
|
||||||
|
from redis.asyncio import RedisCluster
|
||||||
|
|
||||||
|
from .._redis import get_redis_async_client
|
||||||
|
|
||||||
|
if self.redis_cluster_client:
|
||||||
|
return self.redis_cluster_client
|
||||||
|
|
||||||
|
_redis_client = get_redis_async_client(
|
||||||
|
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
|
||||||
|
)
|
||||||
|
if isinstance(_redis_client, RedisCluster):
|
||||||
|
self.redis_cluster_client = _redis_client
|
||||||
|
return _redis_client
|
1
mypy.ini
1
mypy.ini
|
@ -1,6 +1,7 @@
|
||||||
[mypy]
|
[mypy]
|
||||||
warn_return_any = False
|
warn_return_any = False
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
mypy_path = litellm/stubs
|
||||||
|
|
||||||
[mypy-google.*]
|
[mypy-google.*]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
|
@ -21,7 +21,8 @@ import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import aembedding, completion, embedding
|
from litellm import aembedding, completion, embedding
|
||||||
from litellm.caching.caching import Cache
|
from litellm.caching.caching import Cache
|
||||||
|
from redis.asyncio import RedisCluster
|
||||||
|
from litellm.caching.redis_cluster_cache import RedisClusterCache
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock, call
|
from unittest.mock import AsyncMock, patch, MagicMock, call
|
||||||
import datetime
|
import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
@ -2328,8 +2329,12 @@ async def test_redis_caching_ttl_pipeline():
|
||||||
# Verify that the set method was called on the mock Redis instance
|
# Verify that the set method was called on the mock Redis instance
|
||||||
mock_set.assert_has_calls(
|
mock_set.assert_has_calls(
|
||||||
[
|
[
|
||||||
call.set("test_key1", '"test_value1"', ex=expected_timedelta),
|
call.set(
|
||||||
call.set("test_key2", '"test_value2"', ex=expected_timedelta),
|
name="test_key1", value='"test_value1"', ex=expected_timedelta
|
||||||
|
),
|
||||||
|
call.set(
|
||||||
|
name="test_key2", value='"test_value2"', ex=expected_timedelta
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2388,6 +2393,7 @@ async def test_redis_increment_pipeline():
|
||||||
from litellm.caching.redis_cache import RedisCache
|
from litellm.caching.redis_cache import RedisCache
|
||||||
|
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
litellm._turn_on_debug()
|
||||||
redis_cache = RedisCache(
|
redis_cache = RedisCache(
|
||||||
host=os.environ["REDIS_HOST"],
|
host=os.environ["REDIS_HOST"],
|
||||||
port=os.environ["REDIS_PORT"],
|
port=os.environ["REDIS_PORT"],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue