Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 19:24:27 +00:00
(Bug Fix Redis) - Fix running redis.mget operations with None Keys (#8666)

* async_batch_get_cache
* test_batch_get_cache_with_none_keys
* async_batch_get_cache
* fix linting error
Commit 5dc62c9e7b (parent 752e93cbdb)
2 changed files with 79 additions and 15 deletions
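
The core of the fix: redis-py cannot serialize a None key, so a None anywhere in key_list would make the whole mget call fail. Both methods now build a filtered _key_list and zip the results against it. Below is a minimal standalone sketch of that pattern; the safe_mget helper, client setup, and key names are illustrative and not part of the commit.

# Minimal sketch of the None-filtering pattern applied in this commit (illustrative helper).
from typing import Any, Dict, List, Optional

import redis


def safe_mget(client: redis.Redis, key_list: List[Optional[str]]) -> Dict[str, Any]:
    # redis-py cannot serialize None, so drop None entries before calling mget.
    _key_list = [key for key in key_list if key is not None]
    if not _key_list:
        return {}
    results = client.mget(_key_list)
    # Zip against the filtered list so values line up with the keys actually sent.
    return dict(zip(_key_list, results))


# Illustrative usage: unset keys come back with None values.
# client = redis.Redis(host="localhost", port=6379)
# print(safe_mget(client, [None, "my-key", None]))  # {"my-key": None} if "my-key" is unset
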
First changed file — RedisCache implementation:

@@ -637,16 +637,28 @@ class RedisCache(BaseCache):
                 "litellm.caching.caching: get() - Got exception from REDIS: ", e
             )
 
-    def batch_get_cache(self, key_list, parent_otel_span: Optional[Span]) -> dict:
+    def batch_get_cache(
+        self,
+        key_list: Union[List[str], List[Optional[str]]],
+        parent_otel_span: Optional[Span] = None,
+    ) -> dict:
         """
         Use Redis for bulk read operations
+
+        Args:
+            key_list: List of keys to get from Redis
+            parent_otel_span: Optional parent OpenTelemetry span
+
+        Returns:
+            dict: A dictionary mapping keys to their cached values
         """
         key_value_dict = {}
+        _key_list = [key for key in key_list if key is not None]
+
         try:
             _keys = []
-            for cache_key in key_list:
-                cache_key = self.check_and_fix_namespace(key=cache_key)
+            for cache_key in _key_list:
+                cache_key = self.check_and_fix_namespace(key=cache_key or "")
                 _keys.append(cache_key)
             start_time = time.time()
             results: List = self.redis_client.mget(keys=_keys)  # type: ignore
@@ -662,17 +674,19 @@ class RedisCache(BaseCache):
             )
 
             # Associate the results back with their keys.
-            # 'results' is a list of values corresponding to the order of keys in 'key_list'.
-            key_value_dict = dict(zip(key_list, results))
+            # 'results' is a list of values corresponding to the order of keys in '_key_list'.
+            key_value_dict = dict(zip(_key_list, results))
 
-            decoded_results = {
-                k.decode("utf-8"): self._get_cache_logic(v)
-                for k, v in key_value_dict.items()
-            }
+            decoded_results = {}
+            for k, v in key_value_dict.items():
+                if isinstance(k, bytes):
+                    k = k.decode("utf-8")
+                v = self._get_cache_logic(v)
+                decoded_results[k] = v
 
             return decoded_results
         except Exception as e:
-            print_verbose(f"Error occurred in pipeline read - {str(e)}")
+            verbose_logger.error(f"Error occurred in batch get cache - {str(e)}")
             return key_value_dict
 
     async def async_get_cache(
@@ -726,22 +740,33 @@ class RedisCache(BaseCache):
             )
 
     async def async_batch_get_cache(
-        self, key_list: List[str], parent_otel_span: Optional[Span] = None
+        self,
+        key_list: Union[List[str], List[Optional[str]]],
+        parent_otel_span: Optional[Span] = None,
     ) -> dict:
         """
         Use Redis for bulk read operations
+
+        Args:
+            key_list: List of keys to get from Redis
+            parent_otel_span: Optional parent OpenTelemetry span
+
+        Returns:
+            dict: A dictionary mapping keys to their cached values
+
+        `.mget` does not support None keys. This will filter out None keys.
         """
         # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
         _redis_client: Any = self.init_async_client()
         key_value_dict = {}
         start_time = time.time()
+        _key_list = [key for key in key_list if key is not None]
         try:
             _keys = []
-            for cache_key in key_list:
+            for cache_key in _key_list:
                 cache_key = self.check_and_fix_namespace(key=cache_key)
                 _keys.append(cache_key)
             results = await _redis_client.mget(keys=_keys)
 
             ## LOGGING ##
             end_time = time.time()
             _duration = end_time - start_time
@@ -758,7 +783,7 @@ class RedisCache(BaseCache):
 
             # Associate the results back with their keys.
             # 'results' is a list of values corresponding to the order of keys in 'key_list'.
-            key_value_dict = dict(zip(key_list, results))
+            key_value_dict = dict(zip(_key_list, results))
 
             decoded_results = {}
             for k, v in key_value_dict.items():
@@ -783,7 +808,7 @@ class RedisCache(BaseCache):
                     parent_otel_span=parent_otel_span,
                 )
             )
-            print_verbose(f"Error occurred in pipeline read - {str(e)}")
+            verbose_logger.error(f"Error occurred in async batch get cache - {str(e)}")
             return key_value_dict
 
     def sync_ping(self) -> bool:
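
The diff above also replaces the old dict comprehension, which unconditionally called k.decode("utf-8"), with a loop that decodes only bytes keys. A minimal standalone sketch of that decode step follows; the decode_key helper and sample keys are illustrative, not from the commit.

# Illustrative helper mirroring the isinstance-based decode in the new loop.
from typing import Union


def decode_key(key: Union[str, bytes]) -> str:
    # Keys zipped from the caller's key_list are already str; bytes only
    # appear when a Redis client returns raw (undecoded) keys.
    return key.decode("utf-8") if isinstance(key, bytes) else key


print(decode_key(b"ns:user:1"))  # -> "ns:user:1"
print(decode_key("ns:user:1"))   # -> "ns:user:1" (unchanged)
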
Second changed file — caching unit tests:

@@ -94,6 +94,45 @@ def test_dual_cache_batch_get_cache():
     assert result[1] == None
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_batch_get_cache_with_none_keys(sync_mode):
+    """
+    Unit testing for RedisCache batch_get_cache() and async_batch_get_cache()
+    - test with None keys. Ensure it can safely handle when keys are None.
+    - expect result = {key: None}
+    """
+    from litellm.caching.caching import RedisCache
+
+    litellm._turn_on_debug()
+
+    redis_cache = RedisCache(
+        host=os.environ.get("REDIS_HOST"),
+        port=os.environ.get("REDIS_PORT"),
+        password=os.environ.get("REDIS_PASSWORD"),
+    )
+    keys_to_lookup = [
+        None,
+        f"test_value_{uuid.uuid4()}",
+        None,
+        f"test_value_2_{uuid.uuid4()}",
+        None,
+        f"test_value_3_{uuid.uuid4()}",
+    ]
+    if sync_mode:
+        result = redis_cache.batch_get_cache(key_list=keys_to_lookup)
+        print("result from batch_get_cache=", result)
+    else:
+        result = await redis_cache.async_batch_get_cache(key_list=keys_to_lookup)
+        print("result from async_batch_get_cache=", result)
+    expected_result = {}
+    for key in keys_to_lookup:
+        if key is None:
+            continue
+        expected_result[key] = None
+    assert result == expected_result
+
+
 # @pytest.mark.skip(reason="")
 def test_caching_dynamic_args():  # test in memory cache
     try:
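
For reference, a standalone driver mirroring what the new test exercises on the async path. This is an illustrative sketch, not part of the commit; it assumes the same REDIS_HOST/REDIS_PORT/REDIS_PASSWORD environment variables and a placeholder key name.

# Illustrative standalone check of the fixed async path (not part of the commit).
import asyncio
import os

from litellm.caching.caching import RedisCache


async def main() -> None:
    redis_cache = RedisCache(
        host=os.environ.get("REDIS_HOST"),
        port=os.environ.get("REDIS_PORT"),
        password=os.environ.get("REDIS_PASSWORD"),
    )
    # None entries are now filtered out instead of crashing redis.mget.
    result = await redis_cache.async_batch_get_cache(
        key_list=[None, "example-unset-key", None]
    )
    print(result)  # expected: {"example-unset-key": None} when the key is not cached


if __name__ == "__main__":
    asyncio.run(main())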