(Bug Fix Redis) - Fix running redis.mget operations with None Keys (#8666)

* async_batch_get_cache

* test_batch_get_cache_with_none_keys

* async_batch_get_cache

* fix linting error
This commit is contained in:
Ishaan Jaff 2025-02-19 19:56:57 -08:00 committed by GitHub
parent 752e93cbdb
commit 5dc62c9e7b
2 changed files with 79 additions and 15 deletions

View file

@ -637,16 +637,28 @@ class RedisCache(BaseCache):
"litellm.caching.caching: get() - Got exception from REDIS: ", e
)
def batch_get_cache(self, key_list, parent_otel_span: Optional[Span]) -> dict:
def batch_get_cache(
self,
key_list: Union[List[str], List[Optional[str]]],
parent_otel_span: Optional[Span] = None,
) -> dict:
"""
Use Redis for bulk read operations
Args:
key_list: List of keys to get from Redis
parent_otel_span: Optional parent OpenTelemetry span
Returns:
dict: A dictionary mapping keys to their cached values
"""
key_value_dict = {}
_key_list = [key for key in key_list if key is not None]
try:
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
for cache_key in _key_list:
cache_key = self.check_and_fix_namespace(key=cache_key or "")
_keys.append(cache_key)
start_time = time.time()
results: List = self.redis_client.mget(keys=_keys) # type: ignore
@ -662,17 +674,19 @@ class RedisCache(BaseCache):
)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
# 'results' is a list of values corresponding to the order of keys in '_key_list'.
key_value_dict = dict(zip(_key_list, results))
decoded_results = {
k.decode("utf-8"): self._get_cache_logic(v)
for k, v in key_value_dict.items()
}
decoded_results = {}
for k, v in key_value_dict.items():
if isinstance(k, bytes):
k = k.decode("utf-8")
v = self._get_cache_logic(v)
decoded_results[k] = v
return decoded_results
except Exception as e:
print_verbose(f"Error occurred in pipeline read - {str(e)}")
verbose_logger.error(f"Error occurred in batch get cache - {str(e)}")
return key_value_dict
async def async_get_cache(
@ -726,22 +740,33 @@ class RedisCache(BaseCache):
)
async def async_batch_get_cache(
self, key_list: List[str], parent_otel_span: Optional[Span] = None
self,
key_list: Union[List[str], List[Optional[str]]],
parent_otel_span: Optional[Span] = None,
) -> dict:
"""
Use Redis for bulk read operations
Args:
key_list: List of keys to get from Redis
parent_otel_span: Optional parent OpenTelemetry span
Returns:
dict: A dictionary mapping keys to their cached values
`.mget` does not support None keys. This will filter out None keys.
"""
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
_redis_client: Any = self.init_async_client()
key_value_dict = {}
start_time = time.time()
_key_list = [key for key in key_list if key is not None]
try:
_keys = []
for cache_key in key_list:
for cache_key in _key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = await _redis_client.mget(keys=_keys)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
@ -758,7 +783,7 @@ class RedisCache(BaseCache):
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
key_value_dict = dict(zip(_key_list, results))
decoded_results = {}
for k, v in key_value_dict.items():
@ -783,7 +808,7 @@ class RedisCache(BaseCache):
parent_otel_span=parent_otel_span,
)
)
print_verbose(f"Error occurred in pipeline read - {str(e)}")
verbose_logger.error(f"Error occurred in async batch get cache - {str(e)}")
return key_value_dict
def sync_ping(self) -> bool: