diff --git a/litellm/_redis.py b/litellm/_redis.py index 23f82ed1a..135170e25 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -11,7 +11,7 @@ import inspect # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation import os -from typing import List, Optional +from typing import List import redis # type: ignore import redis.asyncio as async_redis # type: ignore @@ -55,19 +55,6 @@ def _get_redis_url_kwargs(client=None): return available_args -def _get_redis_cluster_kwargs(client=None): - if client is None: - client = redis.Redis.from_url - arg_spec = inspect.getfullargspec(redis.RedisCluster) - - # Only allow primitive arguments - exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"} - - available_args = [x for x in arg_spec.args if x not in exclude_args] - - return available_args - - def _get_redis_env_kwarg_mapping(): PREFIX = "REDIS_" @@ -143,13 +130,10 @@ def get_redis_client(**env_overrides): return redis.Redis.from_url(**url_kwargs) if "startup_nodes" in redis_kwargs: - from redis.cluster import ClusterNode + from redis.cluster import ClusterNode, cleanup_kwargs - args = _get_redis_cluster_kwargs() - cluster_kwargs = {} - for arg in redis_kwargs: - if arg in args: - cluster_kwargs[arg] = redis_kwargs[arg] + # Only allow primitive arguments + cluster_kwargs = cleanup_kwargs(**redis_kwargs) new_startup_nodes: List[ClusterNode] = [] @@ -177,14 +161,10 @@ def get_redis_async_client(**env_overrides): return async_redis.Redis.from_url(**url_kwargs) if "startup_nodes" in redis_kwargs: - from redis.cluster import ClusterNode - - args = _get_redis_cluster_kwargs() - cluster_kwargs = {} - for arg in redis_kwargs: - if arg in args: - cluster_kwargs[arg] = redis_kwargs[arg] + from redis.cluster import ClusterNode, cleanup_kwargs + # Only allow primitive arguments + cluster_kwargs = cleanup_kwargs(**redis_kwargs) new_startup_nodes: List[ClusterNode] = [] for item in redis_kwargs["startup_nodes"]: diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index a272d2dcf..7719840ce 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -826,7 +826,7 @@ async def test_redis_cache_cluster_init_unit_test(): assert isinstance(resp.redis_client, RedisCluster) assert isinstance(resp.init_async_client(), AsyncRedisCluster) - resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes) + resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes, password="my-cluster-password") assert isinstance(resp.cache, RedisCache) assert isinstance(resp.cache.redis_client, RedisCluster)