fix(_redis.py): allow all supported arguments for redis cluster (#5554)

This commit is contained in:
Jonas Dittrich 2024-09-07 21:49:02 +02:00 committed by GitHub
parent 4bf893afe1
commit f2191ef4cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 28 deletions

View file

@ -11,7 +11,7 @@ import inspect
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
from typing import List, Optional
from typing import List
import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
@ -55,19 +55,6 @@ def _get_redis_url_kwargs(client=None):
return available_args
def _get_redis_cluster_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(redis.RedisCluster)
# Only allow primitive arguments
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
available_args = [x for x in arg_spec.args if x not in exclude_args]
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
@ -143,13 +130,10 @@ def get_redis_client(**env_overrides):
return redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
from redis.cluster import ClusterNode, cleanup_kwargs
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
# Only allow primitive arguments
cluster_kwargs = cleanup_kwargs(**redis_kwargs)
new_startup_nodes: List[ClusterNode] = []
@ -177,14 +161,10 @@ def get_redis_async_client(**env_overrides):
return async_redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
from redis.cluster import ClusterNode, cleanup_kwargs
# Only allow primitive arguments
cluster_kwargs = cleanup_kwargs(**redis_kwargs)
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:

View file

@ -826,7 +826,7 @@ async def test_redis_cache_cluster_init_unit_test():
assert isinstance(resp.redis_client, RedisCluster)
assert isinstance(resp.init_async_client(), AsyncRedisCluster)
resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes)
resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes, password="my-cluster-password")
assert isinstance(resp.cache, RedisCache)
assert isinstance(resp.cache.redis_client, RedisCluster)