diff --git a/litellm/_redis.py b/litellm/_redis.py index 1c300afcb..8d0fa4e18 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -19,6 +19,8 @@ import redis.asyncio as async_redis # type: ignore import litellm +from ._logging import verbose_logger + def _get_redis_kwargs(): arg_spec = inspect.getfullargspec(redis.Redis) @@ -121,17 +123,56 @@ def _get_redis_client_logic(**env_overrides): **env_overrides, } + _startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret( + "REDIS_CLUSTER_NODES" + ) + + if _startup_nodes is not None: + redis_kwargs["startup_nodes"] = json.loads(_startup_nodes) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: redis_kwargs.pop("host", None) redis_kwargs.pop("port", None) redis_kwargs.pop("db", None) redis_kwargs.pop("password", None) + elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None: + pass elif "host" not in redis_kwargs or redis_kwargs["host"] is None: raise ValueError("Either 'host' or 'url' must be specified for redis.") # litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") return redis_kwargs +def init_redis_cluster(redis_kwargs) -> redis.RedisCluster: + _redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES") + if _redis_cluster_nodes_in_env is not None: + try: + redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env) + except json.JSONDecodeError: + raise ValueError( + "REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted." + ) + + verbose_logger.debug( + "init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"] + ) + from redis.cluster import ClusterNode + + args = _get_redis_cluster_kwargs() + cluster_kwargs = {} + for arg in redis_kwargs: + if arg in args: + cluster_kwargs[arg] = redis_kwargs[arg] + + new_startup_nodes: List[ClusterNode] = [] + + for item in redis_kwargs["startup_nodes"]: + new_startup_nodes.append(ClusterNode(**item)) + + redis_kwargs.pop("startup_nodes") + return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) + + def get_redis_client(**env_overrides): redis_kwargs = _get_redis_client_logic(**env_overrides) if "url" in redis_kwargs and redis_kwargs["url"] is not None: @@ -147,24 +188,8 @@ def get_redis_client(**env_overrides): "startup_nodes" in redis_kwargs or litellm.get_secret("REDIS_CLUSTER_NODES") is not None ): - _redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES") - if _redis_cluster_nodes_in_env is not None: - redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env) + return init_redis_cluster(redis_kwargs) - from redis.cluster import ClusterNode - - args = _get_redis_cluster_kwargs() - cluster_kwargs = {} - for arg in redis_kwargs: - if arg in args: - cluster_kwargs[arg] = redis_kwargs[arg] - - new_startup_nodes: List[ClusterNode] = [] - - for item in redis_kwargs["startup_nodes"]: - new_startup_nodes.append(ClusterNode(**item)) - redis_kwargs.pop("startup_nodes") - return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) return redis.Redis(**redis_kwargs)