diff --git a/litellm/_redis.py b/litellm/_redis.py index 70c38cf7f5..1e03993c20 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -183,7 +183,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster: ) verbose_logger.debug( - "init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"] + "init_redis_cluster: startup nodes are being initialized." ) from redis.cluster import ClusterNode @@ -266,7 +266,9 @@ def get_redis_client(**env_overrides): return redis.Redis(**redis_kwargs) -def get_redis_async_client(**env_overrides) -> async_redis.Redis: +def get_redis_async_client( + **env_overrides, +) -> async_redis.Redis: redis_kwargs = _get_redis_client_logic(**env_overrides) if "url" in redis_kwargs and redis_kwargs["url"] is not None: args = _get_redis_url_kwargs(client=async_redis.Redis.from_url) diff --git a/litellm/caching/__init__.py b/litellm/caching/__init__.py index f10675f5e0..e10d01ff02 100644 --- a/litellm/caching/__init__.py +++ b/litellm/caching/__init__.py @@ -4,5 +4,6 @@ from .dual_cache import DualCache from .in_memory_cache import InMemoryCache from .qdrant_semantic_cache import QdrantSemanticCache from .redis_cache import RedisCache +from .redis_cluster_cache import RedisClusterCache from .redis_semantic_cache import RedisSemanticCache from .s3_cache import S3Cache diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py index f7842ad48a..90e37b07db 100644 --- a/litellm/caching/caching.py +++ b/litellm/caching/caching.py @@ -41,6 +41,7 @@ from .dual_cache import DualCache # noqa from .in_memory_cache import InMemoryCache from .qdrant_semantic_cache import QdrantSemanticCache from .redis_cache import RedisCache +from .redis_cluster_cache import RedisClusterCache from .redis_semantic_cache import RedisSemanticCache from .s3_cache import S3Cache @@ -158,14 +159,23 @@ class Cache: None. Cache is set as a litellm param """ if type == LiteLLMCacheType.REDIS: - self.cache: BaseCache = RedisCache( - host=host, - port=port, - password=password, - redis_flush_size=redis_flush_size, - startup_nodes=redis_startup_nodes, - **kwargs, - ) + if redis_startup_nodes: + self.cache: BaseCache = RedisClusterCache( + host=host, + port=port, + password=password, + redis_flush_size=redis_flush_size, + startup_nodes=redis_startup_nodes, + **kwargs, + ) + else: + self.cache = RedisCache( + host=host, + port=port, + password=password, + redis_flush_size=redis_flush_size, + **kwargs, + ) elif type == LiteLLMCacheType.REDIS_SEMANTIC: self.cache = RedisSemanticCache( host=host, diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 21455fa7f2..650bb7955c 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -14,7 +14,7 @@ import inspect import json import time from datetime import timedelta -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import litellm from litellm._logging import print_verbose, verbose_logger @@ -26,15 +26,20 @@ from .base_cache import BaseCache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - from redis.asyncio import Redis + from redis.asyncio import Redis, RedisCluster from redis.asyncio.client import Pipeline + from redis.asyncio.cluster import ClusterPipeline pipeline = Pipeline + cluster_pipeline = ClusterPipeline async_redis_client = Redis + async_redis_cluster_client = RedisCluster Span = _Span else: pipeline = Any + cluster_pipeline = Any async_redis_client = Any + async_redis_cluster_client = Any Span = Any @@ -122,7 +127,9 @@ class RedisCache(BaseCache): else: super().__init__() # defaults to 60s - def init_async_client(self): + def init_async_client( + self, + ) -> Union[async_redis_client, async_redis_cluster_client]: from .._redis import get_redis_async_client return get_redis_async_client( @@ -345,8 +352,14 @@ class RedisCache(BaseCache): ) async def _pipeline_helper( - self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] + self, + pipe: Union[pipeline, cluster_pipeline], + cache_list: List[Tuple[Any, Any]], + ttl: Optional[float], ) -> List: + """ + Helper function for executing a pipeline of set operations on Redis + """ ttl = self.get_ttl(ttl=ttl) # Iterate through each key-value pair in the cache_list and set them in the pipeline. for cache_key, cache_value in cache_list: @@ -359,7 +372,11 @@ class RedisCache(BaseCache): _td: Optional[timedelta] = None if ttl is not None: _td = timedelta(seconds=ttl) - pipe.set(cache_key, json_cache_value, ex=_td) + pipe.set( # type: ignore + name=cache_key, + value=json_cache_value, + ex=_td, + ) # Execute the pipeline and return the results. results = await pipe.execute() return results @@ -373,9 +390,8 @@ class RedisCache(BaseCache): # don't waste a network request if there's nothing to set if len(cache_list) == 0: return - from redis.asyncio import Redis - _redis_client: Redis = self.init_async_client() # type: ignore + _redis_client = self.init_async_client() start_time = time.time() print_verbose( @@ -384,7 +400,7 @@ class RedisCache(BaseCache): cache_value: Any = None try: async with _redis_client as redis_client: - async with redis_client.pipeline(transaction=True) as pipe: + async with redis_client.pipeline(transaction=False) as pipe: results = await self._pipeline_helper(pipe, cache_list, ttl) print_verbose(f"pipeline results: {results}") @@ -730,7 +746,8 @@ class RedisCache(BaseCache): """ Use Redis for bulk read operations """ - _redis_client = await self.init_async_client() + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget` + _redis_client: Any = self.init_async_client() key_value_dict = {} start_time = time.time() try: @@ -822,7 +839,8 @@ class RedisCache(BaseCache): raise e async def ping(self) -> bool: - _redis_client = self.init_async_client() + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping` + _redis_client: Any = self.init_async_client() start_time = time.time() async with _redis_client as redis_client: print_verbose("Pinging Async Redis Cache") @@ -858,7 +876,8 @@ class RedisCache(BaseCache): raise e async def delete_cache_keys(self, keys): - _redis_client = self.init_async_client() + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete` + _redis_client: Any = self.init_async_client() # keys is a list, unpack it so it gets passed as individual elements to delete async with _redis_client as redis_client: await redis_client.delete(*keys) @@ -881,7 +900,8 @@ class RedisCache(BaseCache): await self.async_redis_conn_pool.disconnect(inuse_connections=True) async def async_delete_cache(self, key: str): - _redis_client = self.init_async_client() + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete` + _redis_client: Any = self.init_async_client() # keys is str async with _redis_client as redis_client: await redis_client.delete(key) @@ -936,7 +956,7 @@ class RedisCache(BaseCache): try: async with _redis_client as redis_client: - async with redis_client.pipeline(transaction=True) as pipe: + async with redis_client.pipeline(transaction=False) as pipe: results = await self._pipeline_increment_helper( pipe, increment_list ) @@ -991,7 +1011,8 @@ class RedisCache(BaseCache): Redis ref: https://redis.io/docs/latest/commands/ttl/ """ try: - _redis_client = await self.init_async_client() + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl` + _redis_client: Any = self.init_async_client() async with _redis_client as redis_client: ttl = await redis_client.ttl(key) if ttl <= -1: # -1 means the key does not exist, -2 key does not exist diff --git a/litellm/caching/redis_cluster_cache.py b/litellm/caching/redis_cluster_cache.py new file mode 100644 index 0000000000..397ea89790 --- /dev/null +++ b/litellm/caching/redis_cluster_cache.py @@ -0,0 +1,44 @@ +""" +Redis Cluster Cache implementation + +Key differences: +- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created +""" + +from typing import TYPE_CHECKING, Any, Optional + +from litellm.caching.redis_cache import RedisCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + from redis.asyncio import Redis, RedisCluster + from redis.asyncio.client import Pipeline + + pipeline = Pipeline + async_redis_client = Redis + Span = _Span +else: + pipeline = Any + async_redis_client = Any + Span = Any + + +class RedisClusterCache(RedisCache): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.redis_cluster_client: Optional[RedisCluster] = None + + def init_async_client(self): + from redis.asyncio import RedisCluster + + from .._redis import get_redis_async_client + + if self.redis_cluster_client: + return self.redis_cluster_client + + _redis_client = get_redis_async_client( + connection_pool=self.async_redis_conn_pool, **self.redis_kwargs + ) + if isinstance(_redis_client, RedisCluster): + self.redis_cluster_client = _redis_client + return _redis_client diff --git a/mypy.ini b/mypy.ini index 82560ef184..19ead3ba7d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,7 @@ [mypy] warn_return_any = False ignore_missing_imports = True +mypy_path = litellm/stubs [mypy-google.*] ignore_missing_imports = True diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index a8452249e9..04110dae4e 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -21,7 +21,8 @@ import pytest import litellm from litellm import aembedding, completion, embedding from litellm.caching.caching import Cache - +from redis.asyncio import RedisCluster +from litellm.caching.redis_cluster_cache import RedisClusterCache from unittest.mock import AsyncMock, patch, MagicMock, call import datetime from datetime import timedelta @@ -2328,8 +2329,12 @@ async def test_redis_caching_ttl_pipeline(): # Verify that the set method was called on the mock Redis instance mock_set.assert_has_calls( [ - call.set("test_key1", '"test_value1"', ex=expected_timedelta), - call.set("test_key2", '"test_value2"', ex=expected_timedelta), + call.set( + name="test_key1", value='"test_value1"', ex=expected_timedelta + ), + call.set( + name="test_key2", value='"test_value2"', ex=expected_timedelta + ), ] ) @@ -2388,6 +2393,7 @@ async def test_redis_increment_pipeline(): from litellm.caching.redis_cache import RedisCache litellm.set_verbose = True + litellm._turn_on_debug() redis_cache = RedisCache( host=os.environ["REDIS_HOST"], port=os.environ["REDIS_PORT"],