fix allow using .env vars for redis cluster

This commit is contained in:
Ishaan Jaff 2024-09-07 08:54:40 -07:00
parent 9225d31776
commit 5c4f3a9a34

View file

@ -19,6 +19,8 @@ import redis.asyncio as async_redis # type: ignore
import litellm
from ._logging import verbose_logger
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
@ -121,17 +123,56 @@ def _get_redis_client_logic(**env_overrides):
**env_overrides,
}
_startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret(
"REDIS_CLUSTER_NODES"
)
if _startup_nodes is not None:
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
redis_kwargs.pop("db", None)
redis_kwargs.pop("password", None)
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
pass
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
_redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES")
if _redis_cluster_nodes_in_env is not None:
try:
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
except json.JSONDecodeError:
raise ValueError(
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
)
verbose_logger.debug(
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
)
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
@ -147,24 +188,8 @@ def get_redis_client(**env_overrides):
"startup_nodes" in redis_kwargs
or litellm.get_secret("REDIS_CLUSTER_NODES") is not None
):
_redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES")
if _redis_cluster_nodes_in_env is not None:
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
return init_redis_cluster(redis_kwargs)
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.Redis(**redis_kwargs)