From f2191ef4cb049bdb8a32d0846690f4a8519df143 Mon Sep 17 00:00:00 2001 From: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com> Date: Sat, 7 Sep 2024 21:49:02 +0200 Subject: [PATCH] fix(_redis.py): allow all supported arguments for redis cluster (#5554) --- litellm/_redis.py | 34 +++++++--------------------------- litellm/tests/test_caching.py | 2 +- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/litellm/_redis.py b/litellm/_redis.py index 23f82ed1a..135170e25 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -11,7 +11,7 @@ import inspect # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation import os -from typing import List, Optional +from typing import List import redis # type: ignore import redis.asyncio as async_redis # type: ignore @@ -55,19 +55,6 @@ def _get_redis_url_kwargs(client=None): return available_args -def _get_redis_cluster_kwargs(client=None): - if client is None: - client = redis.Redis.from_url - arg_spec = inspect.getfullargspec(redis.RedisCluster) - - # Only allow primitive arguments - exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"} - - available_args = [x for x in arg_spec.args if x not in exclude_args] - - return available_args - - def _get_redis_env_kwarg_mapping(): PREFIX = "REDIS_" @@ -143,13 +130,10 @@ def get_redis_client(**env_overrides): return redis.Redis.from_url(**url_kwargs) if "startup_nodes" in redis_kwargs: - from redis.cluster import ClusterNode + from redis.cluster import ClusterNode, cleanup_kwargs - args = _get_redis_cluster_kwargs() - cluster_kwargs = {} - for arg in redis_kwargs: - if arg in args: - cluster_kwargs[arg] = redis_kwargs[arg] + # Only allow primitive arguments + cluster_kwargs = cleanup_kwargs(**redis_kwargs) new_startup_nodes: List[ClusterNode] = [] @@ -177,14 +161,10 @@ def get_redis_async_client(**env_overrides): return async_redis.Redis.from_url(**url_kwargs) if "startup_nodes" in redis_kwargs: - from redis.cluster import ClusterNode - - args = _get_redis_cluster_kwargs() - cluster_kwargs = {} - for arg in redis_kwargs: - if arg in args: - cluster_kwargs[arg] = redis_kwargs[arg] + from redis.cluster import ClusterNode, cleanup_kwargs + # Only allow primitive arguments + cluster_kwargs = cleanup_kwargs(**redis_kwargs) new_startup_nodes: List[ClusterNode] = [] for item in redis_kwargs["startup_nodes"]: diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index a272d2dcf..7719840ce 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -826,7 +826,7 @@ async def test_redis_cache_cluster_init_unit_test(): assert isinstance(resp.redis_client, RedisCluster) assert isinstance(resp.init_async_client(), AsyncRedisCluster) - resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes) + resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes, password="my-cluster-password") assert isinstance(resp.cache, RedisCache) assert isinstance(resp.cache.redis_client, RedisCluster)