Merge pull request #5579 from BerriAI/litellm_set_redis_cluster_env

[Feat] Allow setting up Redis Cluster using .env vars
This commit is contained in:
Ishaan Jaff 2024-09-07 11:31:38 -07:00 committed by GitHub
commit 009a1f7f86
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 139 additions and 15 deletions

View file

@ -54,6 +54,10 @@ litellm_caching:<hash>
#### Redis Cluster
<Tabs>
<TabItem value="redis-cluster-config" label="Set on config.yaml">
```yaml
model_list:
- model_name: "*"
@ -68,6 +72,44 @@ litellm_settings:
redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}]
```
</TabItem>
<TabItem value="redis-env" label="Set on .env">
You can configure redis cluster in your .env by setting `REDIS_CLUSTER_NODES` in your .env
**Example `REDIS_CLUSTER_NODES`** value
```
REDIS_CLUSTER_NODES = "[{"host": "127.0.0.1", "port": "7001"}, {"host": "127.0.0.1", "port": "7003"}, {"host": "127.0.0.1", "port": "7004"}, {"host": "127.0.0.1", "port": "7005"}, {"host": "127.0.0.1", "port": "7006"}, {"host": "127.0.0.1", "port": "7007"}]"
```
:::note
Example python script for setting redis cluster nodes in .env:
```python
# List of startup nodes
startup_nodes = [
{"host": "127.0.0.1", "port": "7001"},
{"host": "127.0.0.1", "port": "7003"},
{"host": "127.0.0.1", "port": "7004"},
{"host": "127.0.0.1", "port": "7005"},
{"host": "127.0.0.1", "port": "7006"},
{"host": "127.0.0.1", "port": "7007"},
]
# set startup nodes in environment variables
os.environ["REDIS_CLUSTER_NODES"] = json.dumps(startup_nodes)
print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"])
```
:::
</TabItem>
</Tabs>
#### TTL
```yaml

View file

@ -8,6 +8,7 @@
# Thank you users! We ❤️ you! - Krrish & Ishaan
import inspect
import json
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
@ -18,6 +19,8 @@ import redis.asyncio as async_redis # type: ignore
import litellm
from ._logging import verbose_logger
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
@ -64,6 +67,7 @@ def _get_redis_cluster_kwargs(client=None):
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
available_args = [x for x in arg_spec.args if x not in exclude_args]
available_args.append("password")
return available_args
@ -120,17 +124,56 @@ def _get_redis_client_logic(**env_overrides):
**env_overrides,
}
_startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret(
"REDIS_CLUSTER_NODES"
)
if _startup_nodes is not None:
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
redis_kwargs.pop("db", None)
redis_kwargs.pop("password", None)
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
pass
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
_redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES")
if _redis_cluster_nodes_in_env is not None:
try:
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
except json.JSONDecodeError:
raise ValueError(
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
)
verbose_logger.debug(
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
)
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
@ -142,21 +185,12 @@ def get_redis_client(**env_overrides):
return redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
if (
"startup_nodes" in redis_kwargs
or litellm.get_secret("REDIS_CLUSTER_NODES") is not None
):
return init_redis_cluster(redis_kwargs)
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.Redis(**redis_kwargs)

View file

@ -290,7 +290,7 @@ async def test_langfuse_logging_audio_transcriptions(langfuse_client):
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
@pytest.mark.flaky(retries=5, delay=1)
async def test_langfuse_masked_input_output(langfuse_client):
"""
Test that creates a trace with masked input and output

View file

@ -837,6 +837,54 @@ async def test_redis_cache_cluster_init_unit_test():
raise e
@pytest.mark.asyncio
@pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
async def test_redis_cache_cluster_init_with_env_vars_unit_test():
try:
import json
from redis.asyncio import RedisCluster as AsyncRedisCluster
from redis.cluster import RedisCluster
from litellm.caching import RedisCache
litellm.set_verbose = True
# List of startup nodes
startup_nodes = [
{"host": "127.0.0.1", "port": "7001"},
{"host": "127.0.0.1", "port": "7003"},
{"host": "127.0.0.1", "port": "7004"},
{"host": "127.0.0.1", "port": "7005"},
{"host": "127.0.0.1", "port": "7006"},
{"host": "127.0.0.1", "port": "7007"},
]
# set startup nodes in environment variables
os.environ["REDIS_CLUSTER_NODES"] = json.dumps(startup_nodes)
print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"])
# unser REDIS_HOST, REDIS_PORT, REDIS_PASSWORD
os.environ.pop("REDIS_HOST", None)
os.environ.pop("REDIS_PORT", None)
os.environ.pop("REDIS_PASSWORD", None)
resp = RedisCache()
print("response from redis cache", resp)
assert isinstance(resp.redis_client, RedisCluster)
assert isinstance(resp.init_async_client(), AsyncRedisCluster)
resp = litellm.Cache(type="redis")
assert isinstance(resp.cache, RedisCache)
assert isinstance(resp.cache.redis_client, RedisCluster)
assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster)
except Exception as e:
print(f"{str(e)}\n\n{traceback.format_exc()}")
raise e
@pytest.mark.asyncio
async def test_redis_cache_acompletion_stream():
try: