diff --git a/litellm/_redis.py b/litellm/_redis.py
index d72016dcd9..23f82ed1a7 100644
--- a/litellm/_redis.py
+++ b/litellm/_redis.py
@@ -7,13 +7,17 @@
 #
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 
+import inspect
+
 # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
 import os
-import inspect
-import redis, litellm  # type: ignore
-import redis.asyncio as async_redis  # type: ignore
 from typing import List, Optional
 
+import redis  # type: ignore
+import redis.asyncio as async_redis  # type: ignore
+
+import litellm
+
 
 def _get_redis_kwargs():
     arg_spec = inspect.getfullargspec(redis.Redis)
@@ -51,6 +55,25 @@ def _get_redis_url_kwargs(client=None):
     return available_args
 
 
+def _get_redis_cluster_kwargs(client=None):
+    """
+    Return the keyword-argument names that may be forwarded to `redis.RedisCluster`.
+
+    `client` is accepted but not used.
+    """
+    arg_spec = inspect.getfullargspec(redis.RedisCluster)
+
+    # Only allow primitive arguments
+    exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
+
+    available_args = [x for x in arg_spec.args if x not in exclude_args]
+    # RedisCluster takes auth/TLS options through **kwargs, so they never show up
+    # in the argspec; allow them explicitly or they are dropped before connecting.
+    available_args.extend(["password", "username", "ssl"])
+
+    return available_args
+
+
 def _get_redis_env_kwarg_mapping():
     PREFIX = "REDIS_"
 
@@ -124,6 +147,22 @@ def get_redis_client(**env_overrides):
                 url_kwargs[arg] = redis_kwargs[arg]
 
         return redis.Redis.from_url(**url_kwargs)
+
+    if "startup_nodes" in redis_kwargs:
+        from redis.cluster import ClusterNode
+
+        args = _get_redis_cluster_kwargs()
+        cluster_kwargs = {}
+        for arg in redis_kwargs:
+            if arg in args:
+                cluster_kwargs[arg] = redis_kwargs[arg]
+
+        new_startup_nodes: List[ClusterNode] = []
+
+        for item in redis_kwargs["startup_nodes"]:
+            new_startup_nodes.append(ClusterNode(**item))
+        redis_kwargs.pop("startup_nodes")
+        return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
     return redis.Redis(**redis_kwargs)
 
 
@@ -143,6 +182,24 @@ def get_redis_async_client(**env_overrides):
                 )
         return async_redis.Redis.from_url(**url_kwargs)
 
+    if "startup_nodes" in redis_kwargs:
+        from redis.cluster import ClusterNode
+
+        args = _get_redis_cluster_kwargs()
+        cluster_kwargs = {}
+        for arg in redis_kwargs:
+            if arg in args:
+                cluster_kwargs[arg] = redis_kwargs[arg]
+
+        new_startup_nodes: List[ClusterNode] = []
+
+        for item in redis_kwargs["startup_nodes"]:
+            new_startup_nodes.append(ClusterNode(**item))
+        redis_kwargs.pop("startup_nodes")
+        return async_redis.RedisCluster(
+            startup_nodes=new_startup_nodes, **cluster_kwargs
+        )
+
     return async_redis.Redis(
         socket_timeout=5,
         **redis_kwargs,
@@ -160,4 +217,5 @@ def get_redis_connection_pool(**env_overrides):
         connection_class = async_redis.SSLConnection
         redis_kwargs.pop("ssl", None)
         redis_kwargs["connection_class"] = connection_class
+    redis_kwargs.pop("startup_nodes", None)
     return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
diff --git a/litellm/caching.py b/litellm/caching.py
index 1c72160295..1b19fdf3e5 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -203,6 +203,7 @@ class RedisCache(BaseCache):
         password=None,
         redis_flush_size=100,
         namespace: Optional[str] = None,
+        startup_nodes: Optional[List] = None,  # for redis-cluster
         **kwargs,
     ):
         import redis
@@ -218,7 +219,8 @@ class RedisCache(BaseCache):
             redis_kwargs["port"] = port
         if password is not None:
             redis_kwargs["password"] = password
-
+        if startup_nodes is not None:
+            redis_kwargs["startup_nodes"] = startup_nodes
         ### HEALTH MONITORING OBJECT ###
         if kwargs.get("service_logger_obj", None) is not None and isinstance(
             kwargs["service_logger_obj"], ServiceLogging
@@ -246,7 +248,7 @@ class RedisCache(BaseCache):
         ### ASYNC HEALTH PING ###
         try:
             # asyncio.get_running_loop().create_task(self.ping())
-            result = asyncio.get_running_loop().create_task(self.ping())
+            _ = asyncio.get_running_loop().create_task(self.ping())
         except Exception as e:
             if "no running event loop" in str(e):
                 verbose_logger.debug(
@@ -2123,6 +2125,7 @@ class Cache:
         redis_semantic_cache_use_async=False,
         redis_semantic_cache_embedding_model="text-embedding-ada-002",
         redis_flush_size=None,
+        redis_startup_nodes: Optional[List] = None,
         disk_cache_dir=None,
         qdrant_api_base: Optional[str] = None,
         qdrant_api_key: Optional[str] = None,
@@ -2155,7 +2158,12 @@ class Cache:
         """
         if type == "redis":
             self.cache: BaseCache = RedisCache(
-                host, port, password, redis_flush_size, **kwargs
+                host,
+                port,
+                password,
+                redis_flush_size,
+                startup_nodes=redis_startup_nodes,
+                **kwargs,
             )
         elif type == "redis-semantic":
             self.cache = RedisSemanticCache(
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 96a0242a8e..10d608ec89 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -2,3 +2,10 @@ model_list:
   - model_name: "*"
     litellm_params:
       model: "*"
+
+
+litellm_settings:
+  cache: True
+  cache_params:
+    type: redis
+    redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a9d0325d80..8986b587b7 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1583,7 +1583,7 @@ class ProxyConfig:
                     verbose_proxy_logger.debug(  # noqa
                         f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
                     )
-                elif key == "cache" and value == False:
+                elif key == "cache" and value is False:
                     pass
                 elif key == "guardrails":
                     if premium_user is not True:
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 64196e5c56..5da883f4ae 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -804,6 +804,38 @@ def test_redis_cache_completion_stream():
 # test_redis_cache_completion_stream()
 
 
+# @pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
+@pytest.mark.asyncio
+async def test_redis_cache_cluster_init_unit_test():
+    try:
+        from redis.asyncio import RedisCluster as AsyncRedisCluster
+        from redis.cluster import RedisCluster
+
+        from litellm.caching import RedisCache
+
+        litellm.set_verbose = True
+
+        # List of startup nodes
+        startup_nodes = [
+            {"host": "127.0.0.1", "port": "7001"},
+        ]
+
+        resp = RedisCache(startup_nodes=startup_nodes)
+
+        assert isinstance(resp.redis_client, RedisCluster)
+        assert isinstance(resp.init_async_client(), AsyncRedisCluster)
+
+        resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes)
+
+        assert isinstance(resp.cache, RedisCache)
+        assert isinstance(resp.cache.redis_client, RedisCluster)
+        assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster)
+
+    except Exception as e:
+        print(f"{str(e)}\n\n{traceback.format_exc()}")
+        raise e
+
+
 @pytest.mark.asyncio
 async def test_redis_cache_acompletion_stream():
     try: