feat(caching.py): redis cluster support

Closes https://github.com/BerriAI/litellm/issues/4358
This commit is contained in:
Krrish Dholakia 2024-08-21 15:01:52 -07:00
parent ac5c6c8751
commit 33c9c16388
5 changed files with 106 additions and 7 deletions

View file

@ -7,13 +7,17 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import inspect
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
import inspect
import redis, litellm # type: ignore
import redis.asyncio as async_redis # type: ignore
from typing import List, Optional
import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
import litellm
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
@ -51,6 +55,19 @@ def _get_redis_url_kwargs(client=None):
return available_args
def _get_redis_cluster_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(redis.RedisCluster)
# Only allow primitive arguments
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
available_args = [x for x in arg_spec.args if x not in exclude_args]
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
@ -124,6 +141,22 @@ def get_redis_client(**env_overrides):
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.Redis(**redis_kwargs)
@ -143,6 +176,24 @@ def get_redis_async_client(**env_overrides):
)
return async_redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return async_redis.RedisCluster(
startup_nodes=new_startup_nodes, **cluster_kwargs
)
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,
@ -160,4 +211,5 @@ def get_redis_connection_pool(**env_overrides):
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
redis_kwargs.pop("startup_nodes", None)
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)