diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index a272d2dcf..a82d7bc8e 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -837,6 +837,52 @@ async def test_redis_cache_cluster_init_unit_test(): raise e +@pytest.mark.asyncio +async def test_redis_cache_cluster_init_with_env_vars_unit_test(): + try: + import json + + from redis.asyncio import RedisCluster as AsyncRedisCluster + from redis.cluster import RedisCluster + + from litellm.caching import RedisCache + + litellm.set_verbose = True + + # List of startup nodes + startup_nodes = [ + {"host": "127.0.0.1", "port": "7001"}, + {"host": "127.0.0.1", "port": "7003"}, + {"host": "127.0.0.1", "port": "7004"}, + {"host": "127.0.0.1", "port": "7005"}, + {"host": "127.0.0.1", "port": "7006"}, + {"host": "127.0.0.1", "port": "7007"}, + ] + + # set startup nodes in environment variables + os.environ["REDIS_CLUSTER_NODES"] = json.dumps(startup_nodes) + + # unser REDIS_HOST, REDIS_PORT, REDIS_PASSWORD + os.environ.pop("REDIS_HOST", None) + os.environ.pop("REDIS_PORT", None) + os.environ.pop("REDIS_PASSWORD", None) + + resp = RedisCache() + print("response from redis cache", resp) + assert isinstance(resp.redis_client, RedisCluster) + assert isinstance(resp.init_async_client(), AsyncRedisCluster) + + resp = litellm.Cache(type="redis") + + assert isinstance(resp.cache, RedisCache) + assert isinstance(resp.cache.redis_client, RedisCluster) + assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster) + + except Exception as e: + print(f"{str(e)}\n\n{traceback.format_exc()}") + raise e + + @pytest.mark.asyncio async def test_redis_cache_acompletion_stream(): try: