diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py
index 0451336b80..960d19c3f8 100644
--- a/litellm/caching/redis_cache.py
+++ b/litellm/caching/redis_cache.py
@@ -637,6 +637,23 @@ class RedisCache(BaseCache):
                 "litellm.caching.caching: get() - Got exception from REDIS: ", e
             )
 
+    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Call `mget` on the sync redis client and return its result
+
+        Kept as a separate method so RedisClusterCache can override it
+        """
+        return self.redis_client.mget(keys=keys)  # type: ignore
+
+    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Call `mget` on the async redis client and return its result
+
+        Kept as a separate method so RedisClusterCache can override it
+        """
+        async_redis_client = self.init_async_client()
+        return await async_redis_client.mget(keys=keys)  # type: ignore
+
     def batch_get_cache(
         self,
         key_list: Union[List[str], List[Optional[str]]],
@@ -661,7 +678,7 @@ class RedisCache(BaseCache):
                 cache_key = self.check_and_fix_namespace(key=cache_key or "")
                 _keys.append(cache_key)
             start_time = time.time()
-            results: List = self.redis_client.mget(keys=_keys)  # type: ignore
+            results: List = self._run_redis_mget_operation(keys=_keys)
             end_time = time.time()
             _duration = end_time - start_time
             self.service_logger_obj.service_success_hook(
@@ -757,7 +774,5 @@ class RedisCache(BaseCache):
         `.mget` does not support None keys. This will filter out None keys.
         """
-        # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
-        _redis_client: Any = self.init_async_client()
         key_value_dict = {}
         start_time = time.time()
         _key_list = [key for key in key_list if key is not None]
@@ -766,7 +781,7 @@ class RedisCache(BaseCache):
             for cache_key in _key_list:
                 cache_key = self.check_and_fix_namespace(key=cache_key)
                 _keys.append(cache_key)
-            results = await _redis_client.mget(keys=_keys)
+            results = await self._async_run_redis_mget_operation(keys=_keys)
             ## LOGGING ##
             end_time = time.time()
             _duration = end_time - start_time
diff --git a/litellm/caching/redis_cluster_cache.py b/litellm/caching/redis_cluster_cache.py
index 397ea89790..2e7d1de17f 100644
--- a/litellm/caching/redis_cluster_cache.py
+++ b/litellm/caching/redis_cluster_cache.py
@@ -5,7 +5,7 @@ Key differences:
 - RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
 """
 
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
 
 from litellm.caching.redis_cache import RedisCache
 
@@ -26,19 +26,37 @@ else:
 class RedisClusterCache(RedisCache):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.redis_cluster_client: Optional[RedisCluster] = None
+        self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
 
     def init_async_client(self):
         from redis.asyncio import RedisCluster
 
         from .._redis import get_redis_async_client
 
-        if self.redis_cluster_client:
-            return self.redis_cluster_client
+        if self.redis_async_redis_cluster_client:
+            return self.redis_async_redis_cluster_client
 
         _redis_client = get_redis_async_client(
             connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
         )
         if isinstance(_redis_client, RedisCluster):
-            self.redis_cluster_client = _redis_client
+            self.redis_async_redis_cluster_client = _redis_client
+
         return _redis_client
+
+    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Overrides `_run_redis_mget_operation` in redis_cache.py
+
+        Uses `mget_nonatomic` instead of `mget` (per redis-py docs, cluster `mget` needs all keys in one hash slot)
+        """
+        return self.redis_client.mget_nonatomic(keys=keys)  # type: ignore
+
+    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Overrides `_async_run_redis_mget_operation` in redis_cache.py
+
+        Uses `mget_nonatomic` instead of `mget` (per redis-py docs, cluster `mget` needs all keys in one hash slot)
+        """
+        async_redis_cluster_client = self.init_async_client()
+        return await async_redis_cluster_client.mget_nonatomic(keys=keys)  # type: ignore
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index e493b26dd8..a8479264e9 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -15,3 +15,4 @@ litellm_settings:
   cache_params:
     type: redis
     ttl: 600
+    redis_startup_nodes: [{"host": "127.0.0.1", "port": "7100"}]
diff --git a/tests/litellm/caching/test_redis_cluster_cache.py b/tests/litellm/caching/test_redis_cluster_cache.py
new file mode 100644
index 0000000000..a23a1296f2
--- /dev/null
+++ b/tests/litellm/caching/test_redis_cluster_cache.py
@@ -0,0 +1,64 @@
+import os
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.caching.redis_cluster_cache import RedisClusterCache
+
+
+@patch("litellm._redis.init_redis_cluster")
+def test_redis_cluster_batch_get(mock_init_redis_cluster):
+    """
+    Test that RedisClusterCache uses mget_nonatomic instead of mget for batch operations
+    """
+    # Create a mock Redis client
+    mock_redis = MagicMock()
+    mock_redis.mget_nonatomic.return_value = [None, None]  # Simulate no cache hits
+    mock_init_redis_cluster.return_value = mock_redis
+
+    # Create RedisClusterCache instance with mock client
+    cache = RedisClusterCache(
+        startup_nodes=[{"host": "localhost", "port": 6379}],
+        password="hello",
+    )
+
+    # Test batch_get_cache
+    keys = ["key1", "key2"]
+    cache.batch_get_cache(keys)
+
+    # Verify mget_nonatomic was called instead of mget
+    mock_redis.mget_nonatomic.assert_called_once()
+    assert not mock_redis.mget.called
+
+
+@pytest.mark.asyncio
+@patch("litellm._redis.init_redis_cluster")
+async def test_redis_cluster_async_batch_get(mock_init_redis_cluster):
+    """
+    Test that RedisClusterCache uses mget_nonatomic instead of mget for async batch operations
+    """
+    # Create a mock Redis client; mget_nonatomic is awaited, so it must be an AsyncMock
+    mock_redis = MagicMock()
+    mock_redis.mget_nonatomic = AsyncMock(return_value=[None, None])  # Simulate no cache hits
+
+    # Create RedisClusterCache instance with mock client
+    cache = RedisClusterCache(
+        startup_nodes=[{"host": "localhost", "port": 6379}],
+        password="hello",
+    )
+
+    # Mock the init_async_client to return our mock redis client
+    cache.init_async_client = MagicMock(return_value=mock_redis)
+
+    # Test async_batch_get_cache
+    keys = ["key1", "key2"]
+    await cache.async_batch_get_cache(keys)
+
+    # Verify mget_nonatomic was awaited instead of mget
+    mock_redis.mget_nonatomic.assert_awaited_once()
+    assert not mock_redis.mget.called