feat(caching.py): redis cluster support

Closes https://github.com/BerriAI/litellm/issues/4358
This commit is contained in:
Krrish Dholakia 2024-08-21 15:01:52 -07:00
parent ac5c6c8751
commit 33c9c16388
5 changed files with 106 additions and 7 deletions

View file

@@ -7,13 +7,17 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 import inspect
 # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
 import os
-import redis, litellm  # type: ignore
-import redis.asyncio as async_redis  # type: ignore
 from typing import List, Optional
+import redis  # type: ignore
+import redis.asyncio as async_redis  # type: ignore
+import litellm
def _get_redis_kwargs(): def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis) arg_spec = inspect.getfullargspec(redis.Redis)
@ -51,6 +55,19 @@ def _get_redis_url_kwargs(client=None):
return available_args return available_args
def _get_redis_cluster_kwargs(client=None):
    """Return the redis.RedisCluster constructor kwarg names that callers may
    populate from env vars / user config.

    Mirrors ``_get_redis_url_kwargs`` but for cluster mode. Non-primitive or
    cluster-managed arguments (``connection_pool``, ``retry``, ``host``,
    ``port``, ``startup_nodes``) are excluded because they are constructed
    separately or passed explicitly by the caller.

    Args:
        client: optional class/callable to introspect; defaults to
            ``redis.RedisCluster``.

    Returns:
        list[str]: allowed keyword-argument names.
    """
    # Bug fix: the original assigned `client = redis.Redis.from_url` (a
    # copy/paste remnant of _get_redis_url_kwargs) and then ignored it,
    # always introspecting redis.RedisCluster. Default to RedisCluster and
    # actually honor a caller-supplied client.
    if client is None:
        client = redis.RedisCluster
    arg_spec = inspect.getfullargspec(client)

    # Only allow primitive arguments; the excluded ones are handled elsewhere.
    exclude_args = {
        "self",
        "connection_pool",
        "retry",
        "host",
        "port",
        "startup_nodes",
    }

    return [x for x in arg_spec.args if x not in exclude_args]
def _get_redis_env_kwarg_mapping(): def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_" PREFIX = "REDIS_"
@ -124,6 +141,22 @@ def get_redis_client(**env_overrides):
url_kwargs[arg] = redis_kwargs[arg] url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs) return redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.Redis(**redis_kwargs) return redis.Redis(**redis_kwargs)
@ -143,6 +176,24 @@ def get_redis_async_client(**env_overrides):
) )
return async_redis.Redis.from_url(**url_kwargs) return async_redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return async_redis.RedisCluster(
startup_nodes=new_startup_nodes, **cluster_kwargs
)
return async_redis.Redis( return async_redis.Redis(
socket_timeout=5, socket_timeout=5,
**redis_kwargs, **redis_kwargs,
@ -160,4 +211,5 @@ def get_redis_connection_pool(**env_overrides):
connection_class = async_redis.SSLConnection connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None) redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class redis_kwargs["connection_class"] = connection_class
redis_kwargs.pop("startup_nodes", None)
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs) return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)

View file

@ -203,6 +203,7 @@ class RedisCache(BaseCache):
password=None, password=None,
redis_flush_size=100, redis_flush_size=100,
namespace: Optional[str] = None, namespace: Optional[str] = None,
startup_nodes: Optional[List] = None, # for redis-cluster
**kwargs, **kwargs,
): ):
import redis import redis
@ -218,7 +219,8 @@ class RedisCache(BaseCache):
redis_kwargs["port"] = port redis_kwargs["port"] = port
if password is not None: if password is not None:
redis_kwargs["password"] = password redis_kwargs["password"] = password
if startup_nodes is not None:
redis_kwargs["startup_nodes"] = startup_nodes
### HEALTH MONITORING OBJECT ### ### HEALTH MONITORING OBJECT ###
if kwargs.get("service_logger_obj", None) is not None and isinstance( if kwargs.get("service_logger_obj", None) is not None and isinstance(
kwargs["service_logger_obj"], ServiceLogging kwargs["service_logger_obj"], ServiceLogging
@ -246,7 +248,7 @@ class RedisCache(BaseCache):
### ASYNC HEALTH PING ### ### ASYNC HEALTH PING ###
try: try:
# asyncio.get_running_loop().create_task(self.ping()) # asyncio.get_running_loop().create_task(self.ping())
result = asyncio.get_running_loop().create_task(self.ping()) _ = asyncio.get_running_loop().create_task(self.ping())
except Exception as e: except Exception as e:
if "no running event loop" in str(e): if "no running event loop" in str(e):
verbose_logger.debug( verbose_logger.debug(
@ -2123,6 +2125,7 @@ class Cache:
redis_semantic_cache_use_async=False, redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002", redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None, redis_flush_size=None,
redis_startup_nodes: Optional[List] = None,
disk_cache_dir=None, disk_cache_dir=None,
qdrant_api_base: Optional[str] = None, qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None, qdrant_api_key: Optional[str] = None,
@ -2155,7 +2158,12 @@ class Cache:
""" """
if type == "redis": if type == "redis":
self.cache: BaseCache = RedisCache( self.cache: BaseCache = RedisCache(
host, port, password, redis_flush_size, **kwargs host,
port,
password,
redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
) )
elif type == "redis-semantic": elif type == "redis-semantic":
self.cache = RedisSemanticCache( self.cache = RedisSemanticCache(

View file

@ -2,3 +2,10 @@ model_list:
- model_name: "*" - model_name: "*"
litellm_params: litellm_params:
model: "*" model: "*"
litellm_settings:
cache: True
cache_params:
type: redis
redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}]

View file

@ -1583,7 +1583,7 @@ class ProxyConfig:
verbose_proxy_logger.debug( # noqa verbose_proxy_logger.debug( # noqa
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}" f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
) )
elif key == "cache" and value == False: elif key == "cache" and value is False:
pass pass
elif key == "guardrails": elif key == "guardrails":
if premium_user is not True: if premium_user is not True:

View file

@ -804,6 +804,38 @@ def test_redis_cache_completion_stream():
# test_redis_cache_completion_stream() # test_redis_cache_completion_stream()
# @pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
@pytest.mark.asyncio
async def test_redis_cache_cluster_init_unit_test():
    """Sanity-check that supplying startup_nodes makes the caching layer build
    redis *cluster* clients (sync and async) instead of standalone clients,
    both via RedisCache directly and via litellm.Cache(type="redis")."""
    try:
        from redis.asyncio import RedisCluster as AsyncRedisCluster
        from redis.cluster import RedisCluster

        from litellm.caching import RedisCache

        litellm.set_verbose = True

        # A single seed node is enough for client construction.
        cluster_seed_nodes = [
            {"host": "127.0.0.1", "port": "7001"},
        ]

        # Direct RedisCache construction should yield cluster clients.
        redis_cache = RedisCache(startup_nodes=cluster_seed_nodes)
        assert isinstance(redis_cache.redis_client, RedisCluster)
        assert isinstance(redis_cache.init_async_client(), AsyncRedisCluster)

        # Going through the litellm.Cache wrapper should behave the same.
        wrapper_cache = litellm.Cache(
            type="redis", redis_startup_nodes=cluster_seed_nodes
        )
        assert isinstance(wrapper_cache.cache, RedisCache)
        assert isinstance(wrapper_cache.cache.redis_client, RedisCluster)
        assert isinstance(
            wrapper_cache.cache.init_async_client(), AsyncRedisCluster
        )
    except Exception as e:
        print(f"{str(e)}\n\n{traceback.format_exc()}")
        raise e
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_redis_cache_acompletion_stream(): async def test_redis_cache_acompletion_stream():
try: try: