(Redis fix) - use mget_non_atomic (#8682)

* use mget_nonatomic

* redis cluster override MGET op

* fix redis cluster + MGET

* test redis cluster
Ishaan Jaff 2025-02-20 17:51:31 -08:00 committed by GitHub
parent bb6f43d12e
commit ccfbb77b73
4 changed files with 106 additions and 8 deletions
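
Background, as a minimal sketch that is not part of the diff below: on a Redis Cluster, keys are hashed across 16384 slots on different nodes, so a single MGET touching keys in different slots is rejected with a cross-slot error. redis-py's RedisCluster exposes mget_nonatomic, which groups the keys by slot and issues one MGET per group; this commit routes litellm's batch cache reads through that call whenever the cache is a cluster. The node address below is taken from the proxy config later in this commit; the snippet itself is illustrative, not litellm code.

from redis.cluster import ClusterNode, RedisCluster

# one reachable startup node is enough; the client discovers the rest of the cluster
cluster = RedisCluster(startup_nodes=[ClusterNode("127.0.0.1", 7100)])

keys = ["litellm:key1", "litellm:key2"]
# cluster.mget(keys)  # can fail once the keys hash to different slots
values = cluster.mget_nonatomic(keys)  # splits keys by hash slot and fans out MGETs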


@@ -637,6 +637,23 @@ class RedisCache(BaseCache):
                 "litellm.caching.caching: get() - Got exception from REDIS: ", e
             )

+    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Wrapper to call `mget` on the redis client
+
+        We use a wrapper so RedisCluster can override this method
+        """
+        return self.redis_client.mget(keys=keys)  # type: ignore
+
+    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Wrapper to call `mget` on the redis client
+
+        We use a wrapper so RedisCluster can override this method
+        """
+        async_redis_client = self.init_async_client()
+        return await async_redis_client.mget(keys=keys)  # type: ignore
+
     def batch_get_cache(
         self,
         key_list: Union[List[str], List[Optional[str]]],
@@ -661,7 +678,7 @@ class RedisCache(BaseCache):
             cache_key = self.check_and_fix_namespace(key=cache_key or "")
             _keys.append(cache_key)
         start_time = time.time()
-        results: List = self.redis_client.mget(keys=_keys)  # type: ignore
+        results: List = self._run_redis_mget_operation(keys=_keys)
         end_time = time.time()
         _duration = end_time - start_time
         self.service_logger_obj.service_success_hook(
@@ -757,7 +774,6 @@ class RedisCache(BaseCache):
         `.mget` does not support None keys. This will filter out None keys.
         """
         # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
-        _redis_client: Any = self.init_async_client()
         key_value_dict = {}
         start_time = time.time()
         _key_list = [key for key in key_list if key is not None]
@@ -766,7 +782,7 @@ class RedisCache(BaseCache):
         for cache_key in _key_list:
             cache_key = self.check_and_fix_namespace(key=cache_key)
             _keys.append(cache_key)
-        results = await _redis_client.mget(keys=_keys)
+        results = await self._async_run_redis_mget_operation(keys=_keys)
         ## LOGGING ##
         end_time = time.time()
         _duration = end_time - start_time
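
Callers are unaffected by the new wrappers: the batch helpers keep their signatures, and only the command issued underneath can differ per subclass. A minimal usage sketch (assuming the usual host/port constructor arguments; not part of the diff):

from litellm.caching.redis_cache import RedisCache

cache = RedisCache(host="127.0.0.1", port=6379)  # standalone Redis -> plain MGET
hits = cache.batch_get_cache(["chat:key1", "chat:key2"])  # dict mapping each key to its cached value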


@@ -5,7 +5,7 @@ Key differences:
 - RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
 """

-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, List, Optional

 from litellm.caching.redis_cache import RedisCache
@@ -26,19 +26,34 @@ else:
 class RedisClusterCache(RedisCache):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.redis_cluster_client: Optional[RedisCluster] = None
+        self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
+        self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None

     def init_async_client(self):
         from redis.asyncio import RedisCluster

         from .._redis import get_redis_async_client

-        if self.redis_cluster_client:
-            return self.redis_cluster_client
+        if self.redis_async_redis_cluster_client:
+            return self.redis_async_redis_cluster_client

         _redis_client = get_redis_async_client(
             connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
         )

         if isinstance(_redis_client, RedisCluster):
-            self.redis_cluster_client = _redis_client
+            self.redis_async_redis_cluster_client = _redis_client

         return _redis_client
+
+    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Overrides `_run_redis_mget_operation` in redis_cache.py
+        """
+        return self.redis_client.mget_nonatomic(keys=keys)  # type: ignore
+
+    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Overrides `_async_run_redis_mget_operation` in redis_cache.py
+        """
+        async_redis_cluster_client = self.init_async_client()
+        return await async_redis_cluster_client.mget_nonatomic(keys=keys)  # type: ignore
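
For readers unfamiliar with mget_nonatomic: it is redis-py's cluster-safe counterpart to MGET. A rough, simplified sketch of the idea (the real logic lives inside redis-py, not in this commit) is to group keys by hash slot, MGET each group, and stitch the results back into the caller's key order:

from typing import Any, Dict, List

from redis.cluster import RedisCluster


def mget_per_slot(client: RedisCluster, keys: List[str]) -> List[Any]:
    # group keys by the slot they hash to
    slot_to_keys: Dict[int, List[str]] = {}
    for key in keys:
        slot_to_keys.setdefault(client.keyslot(key), []).append(key)

    # one MGET per slot is always valid, since every key in the group shares that slot
    value_by_key: Dict[str, Any] = {}
    for slot_keys in slot_to_keys.values():
        for key, value in zip(slot_keys, client.mget(slot_keys)):
            value_by_key[key] = value

    # restore the caller's original key order
    return [value_by_key[key] for key in keys]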


@@ -15,3 +15,4 @@ litellm_settings:
   cache_params:
     type: redis
     ttl: 600
+    redis_startup_nodes: [{"host": "127.0.0.1", "port": "7100"}]
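
The added redis_startup_nodes entry is what points the proxy's cache at a cluster. The same list-of-nodes shape reappears as the startup_nodes argument used by the tests below; a sketch of constructing the cache directly (node address assumed from the YAML above):

from litellm.caching.redis_cluster_cache import RedisClusterCache

# mirrors the config entry above; one reachable node bootstraps the cluster client
cache = RedisClusterCache(startup_nodes=[{"host": "127.0.0.1", "port": "7100"}])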


@@ -0,0 +1,66 @@
+import json
+import os
+import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.caching.redis_cluster_cache import RedisClusterCache
+
+
+@patch("litellm._redis.init_redis_cluster")
+def test_redis_cluster_batch_get(mock_init_redis_cluster):
+    """
+    Test that RedisClusterCache uses mget_nonatomic instead of mget for batch operations
+    """
+    # Create a mock Redis client
+    mock_redis = MagicMock()
+    mock_redis.mget_nonatomic.return_value = [None, None]  # Simulate no cache hits
+    mock_init_redis_cluster.return_value = mock_redis
+
+    # Create RedisClusterCache instance with mock client
+    cache = RedisClusterCache(
+        startup_nodes=[{"host": "localhost", "port": 6379}],
+        password="hello",
+    )
+
+    # Test batch_get_cache
+    keys = ["key1", "key2"]
+    cache.batch_get_cache(keys)
+
+    # Verify mget_nonatomic was called instead of mget
+    mock_redis.mget_nonatomic.assert_called_once()
+    assert not mock_redis.mget.called
+
+
+@pytest.mark.asyncio
+@patch("litellm._redis.init_redis_cluster")
+async def test_redis_cluster_async_batch_get(mock_init_redis_cluster):
+    """
+    Test that RedisClusterCache uses mget_nonatomic instead of mget for async batch operations
+    """
+    # Create a mock Redis client
+    mock_redis = MagicMock()
+    mock_redis.mget_nonatomic.return_value = [None, None]  # Simulate no cache hits
+
+    # Create RedisClusterCache instance with mock client
+    cache = RedisClusterCache(
+        startup_nodes=[{"host": "localhost", "port": 6379}],
+        password="hello",
+    )
+
+    # Mock the init_async_client to return our mock redis client
+    cache.init_async_client = MagicMock(return_value=mock_redis)
+
+    # Test async_batch_get_cache
+    keys = ["key1", "key2"]
+    await cache.async_batch_get_cache(keys)
+
+    # Verify mget_nonatomic was called instead of mget
+    mock_redis.mget_nonatomic.assert_called_once()
+    assert not mock_redis.mget.called
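
A hedged side note on the mocks, not a change to the commit: both tests only assert which command was invoked, so a plain MagicMock suffices even though the async path awaits the result. If a future test needed the awaited mget_nonatomic call to actually hand values back, the attribute would have to be awaitable, for example:

from unittest.mock import AsyncMock, MagicMock

mock_redis = MagicMock()
# AsyncMock makes `await mock_redis.mget_nonatomic(...)` resolve to the list below
mock_redis.mget_nonatomic = AsyncMock(return_value=["cached-1", None])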