forked from phoenix/litellm-mirror
Merge pull request #5579 from BerriAI/litellm_set_redis_cluster_env
[Feat] Allow setting up Redis Cluster using .env vars
This commit is contained in:
commit
009a1f7f86
4 changed files with 139 additions and 15 deletions
|
@ -54,6 +54,10 @@ litellm_caching:<hash>
|
|||
|
||||
#### Redis Cluster
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="redis-cluster-config" label="Set on config.yaml">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "*"
|
||||
|
@ -68,6 +72,44 @@ litellm_settings:
|
|||
redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}]
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="redis-env" label="Set on .env">
|
||||
|
||||
You can configure redis cluster in your .env by setting `REDIS_CLUSTER_NODES` in your .env
|
||||
|
||||
**Example `REDIS_CLUSTER_NODES`** value
|
||||
|
||||
```
|
||||
REDIS_CLUSTER_NODES = "[{"host": "127.0.0.1", "port": "7001"}, {"host": "127.0.0.1", "port": "7003"}, {"host": "127.0.0.1", "port": "7004"}, {"host": "127.0.0.1", "port": "7005"}, {"host": "127.0.0.1", "port": "7006"}, {"host": "127.0.0.1", "port": "7007"}]"
|
||||
```
|
||||
|
||||
:::note
|
||||
|
||||
Example python script for setting redis cluster nodes in .env:
|
||||
|
||||
```python
|
||||
# List of startup nodes
|
||||
startup_nodes = [
|
||||
{"host": "127.0.0.1", "port": "7001"},
|
||||
{"host": "127.0.0.1", "port": "7003"},
|
||||
{"host": "127.0.0.1", "port": "7004"},
|
||||
{"host": "127.0.0.1", "port": "7005"},
|
||||
{"host": "127.0.0.1", "port": "7006"},
|
||||
{"host": "127.0.0.1", "port": "7007"},
|
||||
]
|
||||
|
||||
# set startup nodes in environment variables
|
||||
os.environ["REDIS_CLUSTER_NODES"] = json.dumps(startup_nodes)
|
||||
print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"])
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
#### TTL
|
||||
|
||||
```yaml
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||
import os
|
||||
|
@ -18,6 +19,8 @@ import redis.asyncio as async_redis # type: ignore
|
|||
|
||||
import litellm
|
||||
|
||||
from ._logging import verbose_logger
|
||||
|
||||
|
||||
def _get_redis_kwargs():
|
||||
arg_spec = inspect.getfullargspec(redis.Redis)
|
||||
|
@ -64,6 +67,7 @@ def _get_redis_cluster_kwargs(client=None):
|
|||
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
|
||||
|
||||
available_args = [x for x in arg_spec.args if x not in exclude_args]
|
||||
available_args.append("password")
|
||||
|
||||
return available_args
|
||||
|
||||
|
@ -120,17 +124,56 @@ def _get_redis_client_logic(**env_overrides):
|
|||
**env_overrides,
|
||||
}
|
||||
|
||||
_startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret(
|
||||
"REDIS_CLUSTER_NODES"
|
||||
)
|
||||
|
||||
if _startup_nodes is not None:
|
||||
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
|
||||
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
redis_kwargs.pop("host", None)
|
||||
redis_kwargs.pop("port", None)
|
||||
redis_kwargs.pop("db", None)
|
||||
redis_kwargs.pop("password", None)
|
||||
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
|
||||
pass
|
||||
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
|
||||
raise ValueError("Either 'host' or 'url' must be specified for redis.")
|
||||
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
||||
return redis_kwargs
|
||||
|
||||
|
||||
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
||||
_redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES")
|
||||
if _redis_cluster_nodes_in_env is not None:
|
||||
try:
|
||||
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
|
||||
)
|
||||
from redis.cluster import ClusterNode
|
||||
|
||||
args = _get_redis_cluster_kwargs()
|
||||
cluster_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
cluster_kwargs[arg] = redis_kwargs[arg]
|
||||
|
||||
new_startup_nodes: List[ClusterNode] = []
|
||||
|
||||
for item in redis_kwargs["startup_nodes"]:
|
||||
new_startup_nodes.append(ClusterNode(**item))
|
||||
|
||||
redis_kwargs.pop("startup_nodes")
|
||||
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
|
||||
|
||||
|
||||
def get_redis_client(**env_overrides):
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
|
@ -142,21 +185,12 @@ def get_redis_client(**env_overrides):
|
|||
|
||||
return redis.Redis.from_url(**url_kwargs)
|
||||
|
||||
if "startup_nodes" in redis_kwargs:
|
||||
from redis.cluster import ClusterNode
|
||||
if (
|
||||
"startup_nodes" in redis_kwargs
|
||||
or litellm.get_secret("REDIS_CLUSTER_NODES") is not None
|
||||
):
|
||||
return init_redis_cluster(redis_kwargs)
|
||||
|
||||
args = _get_redis_cluster_kwargs()
|
||||
cluster_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
cluster_kwargs[arg] = redis_kwargs[arg]
|
||||
|
||||
new_startup_nodes: List[ClusterNode] = []
|
||||
|
||||
for item in redis_kwargs["startup_nodes"]:
|
||||
new_startup_nodes.append(ClusterNode(**item))
|
||||
redis_kwargs.pop("startup_nodes")
|
||||
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
|
||||
return redis.Redis(**redis_kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -290,7 +290,7 @@ async def test_langfuse_logging_audio_transcriptions(langfuse_client):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@pytest.mark.flaky(retries=5, delay=1)
|
||||
async def test_langfuse_masked_input_output(langfuse_client):
|
||||
"""
|
||||
Test that creates a trace with masked input and output
|
||||
|
|
|
@ -837,6 +837,54 @@ async def test_redis_cache_cluster_init_unit_test():
|
|||
raise e
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
|
||||
async def test_redis_cache_cluster_init_with_env_vars_unit_test():
|
||||
try:
|
||||
import json
|
||||
|
||||
from redis.asyncio import RedisCluster as AsyncRedisCluster
|
||||
from redis.cluster import RedisCluster
|
||||
|
||||
from litellm.caching import RedisCache
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
# List of startup nodes
|
||||
startup_nodes = [
|
||||
{"host": "127.0.0.1", "port": "7001"},
|
||||
{"host": "127.0.0.1", "port": "7003"},
|
||||
{"host": "127.0.0.1", "port": "7004"},
|
||||
{"host": "127.0.0.1", "port": "7005"},
|
||||
{"host": "127.0.0.1", "port": "7006"},
|
||||
{"host": "127.0.0.1", "port": "7007"},
|
||||
]
|
||||
|
||||
# set startup nodes in environment variables
|
||||
os.environ["REDIS_CLUSTER_NODES"] = json.dumps(startup_nodes)
|
||||
print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"])
|
||||
|
||||
# unser REDIS_HOST, REDIS_PORT, REDIS_PASSWORD
|
||||
os.environ.pop("REDIS_HOST", None)
|
||||
os.environ.pop("REDIS_PORT", None)
|
||||
os.environ.pop("REDIS_PASSWORD", None)
|
||||
|
||||
resp = RedisCache()
|
||||
print("response from redis cache", resp)
|
||||
assert isinstance(resp.redis_client, RedisCluster)
|
||||
assert isinstance(resp.init_async_client(), AsyncRedisCluster)
|
||||
|
||||
resp = litellm.Cache(type="redis")
|
||||
|
||||
assert isinstance(resp.cache, RedisCache)
|
||||
assert isinstance(resp.cache.redis_client, RedisCluster)
|
||||
assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster)
|
||||
|
||||
except Exception as e:
|
||||
print(f"{str(e)}\n\n{traceback.format_exc()}")
|
||||
raise e
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_cache_acompletion_stream():
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue