(Redis fix) - use mget_non_atomic (#8682)
* use mget_nonatomic
* redis cluster override MGET op
* fix redis cluster + MGET
* test redis cluster
parent: bb6f43d12e
commit: ccfbb77b73
4 changed files with 106 additions and 8 deletions
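
Context for the change: on a Redis Cluster, keys usually hash to different slots, and redis-py's cluster client refuses a cross-slot MGET rather than fanning it out, while `mget_nonatomic` splits the keys by slot, runs one MGET per slot, and returns the values in the original key order. A minimal sketch of the difference, not part of the commit, assuming redis-py >= 4.1 and a cluster node reachable at 127.0.0.1:7100 (the address used in the config change below):

    from redis.cluster import ClusterNode, RedisCluster

    # Assumes a running cluster node at 127.0.0.1:7100 (illustrative only).
    rc = RedisCluster(startup_nodes=[ClusterNode("127.0.0.1", 7100)])

    keys = ["key1", "key2"]  # typically hash to different slots

    # A plain MGET across slots is refused by redis-py's cluster client
    # (a RedisClusterException / CROSSSLOT-style error), so it is left
    # commented out here:
    # rc.mget(keys)

    # mget_nonatomic issues one MGET per slot and merges the results,
    # preserving the order of `keys`:
    values = rc.mget_nonatomic(keys)
    print(dict(zip(keys, values)))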
litellm/caching/redis_cache.py

@@ -637,6 +637,23 @@ class RedisCache(BaseCache):
                 "litellm.caching.caching: get() - Got exception from REDIS: ", e
             )

+    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Wrapper to call `mget` on the redis client
+
+        We use a wrapper so RedisCluster can override this method
+        """
+        return self.redis_client.mget(keys=keys)  # type: ignore
+
+    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Wrapper to call `mget` on the redis client
+
+        We use a wrapper so RedisCluster can override this method
+        """
+        async_redis_client = self.init_async_client()
+        return await async_redis_client.mget(keys=keys)  # type: ignore
+
     def batch_get_cache(
         self,
         key_list: Union[List[str], List[Optional[str]]],

@@ -661,7 +678,7 @@ class RedisCache(BaseCache):
             cache_key = self.check_and_fix_namespace(key=cache_key or "")
             _keys.append(cache_key)
         start_time = time.time()
-        results: List = self.redis_client.mget(keys=_keys)  # type: ignore
+        results: List = self._run_redis_mget_operation(keys=_keys)
         end_time = time.time()
         _duration = end_time - start_time
         self.service_logger_obj.service_success_hook(

@@ -757,7 +774,6 @@ class RedisCache(BaseCache):
         `.mget` does not support None keys. This will filter out None keys.
         """
         # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
-        _redis_client: Any = self.init_async_client()
         key_value_dict = {}
         start_time = time.time()
         _key_list = [key for key in key_list if key is not None]

@@ -766,7 +782,7 @@ class RedisCache(BaseCache):
         for cache_key in _key_list:
             cache_key = self.check_and_fix_namespace(key=cache_key)
             _keys.append(cache_key)
-        results = await _redis_client.mget(keys=_keys)
+        results = await self._async_run_redis_mget_operation(keys=_keys)
         ## LOGGING ##
         end_time = time.time()
         _duration = end_time - start_time
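
With this seam in place, `batch_get_cache` and `async_batch_get_cache` no longer care which MGET variant the backend needs; a subclass only overrides the two small wrappers. A hypothetical usage sketch (the connection arguments are illustrative, not taken from the commit):

    from litellm.caching.redis_cache import RedisCache

    # Plain (non-cluster) Redis keeps using an atomic MGET under the hood.
    cache = RedisCache(host="127.0.0.1", port=6379)  # illustrative connection args

    # batch_get_cache() namespaces the keys, times the call for service logging,
    # and performs the actual fetch via _run_redis_mget_operation(), the hook
    # that RedisClusterCache overrides with mget_nonatomic.
    values = cache.batch_get_cache(key_list=["key1", "key2"])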
litellm/caching/redis_cluster_cache.py

@@ -5,7 +5,7 @@ Key differences:
 - RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
 """

-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, List, Optional

 from litellm.caching.redis_cache import RedisCache


@@ -26,19 +26,34 @@ else:
 class RedisClusterCache(RedisCache):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.redis_cluster_client: Optional[RedisCluster] = None
+        self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
+        self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None

     def init_async_client(self):
         from redis.asyncio import RedisCluster

         from .._redis import get_redis_async_client

-        if self.redis_cluster_client:
-            return self.redis_cluster_client
+        if self.redis_async_redis_cluster_client:
+            return self.redis_async_redis_cluster_client

         _redis_client = get_redis_async_client(
             connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
         )
         if isinstance(_redis_client, RedisCluster):
-            self.redis_cluster_client = _redis_client
+            self.redis_async_redis_cluster_client = _redis_client

         return _redis_client

+    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Overrides `_run_redis_mget_operation` in redis_cache.py
+        """
+        return self.redis_client.mget_nonatomic(keys=keys)  # type: ignore
+
+    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+        """
+        Overrides `_async_run_redis_mget_operation` in redis_cache.py
+        """
+        async_redis_cluster_client = self.init_async_client()
+        return await async_redis_cluster_client.mget_nonatomic(keys=keys)  # type: ignore
@@ -15,3 +15,4 @@ litellm_settings:
   cache_params:
     type: redis
     ttl: 600
+    redis_startup_nodes: [{"host": "127.0.0.1", "port": "7100"}]
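
Tying the pieces together, the cluster cache can be exercised directly; the `startup_nodes` argument below mirrors the new `redis_startup_nodes` config entry, and the batch reads now resolve to `mget_nonatomic` through the overrides above. A sketch, assuming a reachable, password-less cluster at 127.0.0.1:7100:

    from litellm.caching.redis_cluster_cache import RedisClusterCache

    # startup_nodes mirrors the proxy config line added above (assumed reachable).
    cache = RedisClusterCache(
        startup_nodes=[{"host": "127.0.0.1", "port": "7100"}],
    )

    # Both the sync and async batch reads now fan out one MGET per slot
    # instead of sending a single cross-slot MGET.
    values = cache.batch_get_cache(key_list=["key1", "key2"])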
tests/litellm/caching/test_redis_cluster_cache.py (new file, 66 lines)

@@ -0,0 +1,66 @@
+import json
+import os
+import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.caching.redis_cluster_cache import RedisClusterCache
+
+
+@patch("litellm._redis.init_redis_cluster")
+def test_redis_cluster_batch_get(mock_init_redis_cluster):
+    """
+    Test that RedisClusterCache uses mget_nonatomic instead of mget for batch operations
+    """
+    # Create a mock Redis client
+    mock_redis = MagicMock()
+    mock_redis.mget_nonatomic.return_value = [None, None]  # Simulate no cache hits
+    mock_init_redis_cluster.return_value = mock_redis
+
+    # Create RedisClusterCache instance with mock client
+    cache = RedisClusterCache(
+        startup_nodes=[{"host": "localhost", "port": 6379}],
+        password="hello",
+    )
+
+    # Test batch_get_cache
+    keys = ["key1", "key2"]
+    cache.batch_get_cache(keys)
+
+    # Verify mget_nonatomic was called instead of mget
+    mock_redis.mget_nonatomic.assert_called_once()
+    assert not mock_redis.mget.called
+
+
+@pytest.mark.asyncio
+@patch("litellm._redis.init_redis_cluster")
+async def test_redis_cluster_async_batch_get(mock_init_redis_cluster):
+    """
+    Test that RedisClusterCache uses mget_nonatomic instead of mget for async batch operations
+    """
+    # Create a mock Redis client
+    mock_redis = MagicMock()
+    mock_redis.mget_nonatomic.return_value = [None, None]  # Simulate no cache hits
+
+    # Create RedisClusterCache instance with mock client
+    cache = RedisClusterCache(
+        startup_nodes=[{"host": "localhost", "port": 6379}],
+        password="hello",
+    )
+
+    # Mock the init_async_client to return our mock redis client
+    cache.init_async_client = MagicMock(return_value=mock_redis)
+
+    # Test async_batch_get_cache
+    keys = ["key1", "key2"]
+    await cache.async_batch_get_cache(keys)
+
+    # Verify mget_nonatomic was called instead of mget
+    mock_redis.mget_nonatomic.assert_called_once()
+    assert not mock_redis.mget.called