From 5dc62c9e7ba49b4fe4428e5283f5f52c28b7b615 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 19 Feb 2025 19:56:57 -0800 Subject: [PATCH] (Bug Fix Redis) - Fix running redis.mget operations with `None` Keys (#8666) * async_batch_get_cache * test_batch_get_cache_with_none_keys * async_batch_get_cache * fix linting error --- litellm/caching/redis_cache.py | 55 +++++++++++++++++++++-------- tests/local_testing/test_caching.py | 39 ++++++++++++++++++++ 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index d21f72fe6b..0451336b80 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -637,16 +637,28 @@ class RedisCache(BaseCache): "litellm.caching.caching: get() - Got exception from REDIS: ", e ) - def batch_get_cache(self, key_list, parent_otel_span: Optional[Span]) -> dict: + def batch_get_cache( + self, + key_list: Union[List[str], List[Optional[str]]], + parent_otel_span: Optional[Span] = None, + ) -> dict: """ Use Redis for bulk read operations + + Args: + key_list: List of keys to get from Redis + parent_otel_span: Optional parent OpenTelemetry span + + Returns: + dict: A dictionary mapping keys to their cached values """ key_value_dict = {} + _key_list = [key for key in key_list if key is not None] try: _keys = [] - for cache_key in key_list: - cache_key = self.check_and_fix_namespace(key=cache_key) + for cache_key in _key_list: + cache_key = self.check_and_fix_namespace(key=cache_key or "") _keys.append(cache_key) start_time = time.time() results: List = self.redis_client.mget(keys=_keys) # type: ignore @@ -662,17 +674,19 @@ class RedisCache(BaseCache): ) # Associate the results back with their keys. - # 'results' is a list of values corresponding to the order of keys in 'key_list'. - key_value_dict = dict(zip(key_list, results)) + # 'results' is a list of values corresponding to the order of keys in '_key_list'. + key_value_dict = dict(zip(_key_list, results)) - decoded_results = { - k.decode("utf-8"): self._get_cache_logic(v) - for k, v in key_value_dict.items() - } + decoded_results = {} + for k, v in key_value_dict.items(): + if isinstance(k, bytes): + k = k.decode("utf-8") + v = self._get_cache_logic(v) + decoded_results[k] = v return decoded_results except Exception as e: - print_verbose(f"Error occurred in pipeline read - {str(e)}") + verbose_logger.error(f"Error occurred in batch get cache - {str(e)}") return key_value_dict async def async_get_cache( @@ -726,22 +740,33 @@ class RedisCache(BaseCache): ) async def async_batch_get_cache( - self, key_list: List[str], parent_otel_span: Optional[Span] = None + self, + key_list: Union[List[str], List[Optional[str]]], + parent_otel_span: Optional[Span] = None, ) -> dict: """ Use Redis for bulk read operations + + Args: + key_list: List of keys to get from Redis + parent_otel_span: Optional parent OpenTelemetry span + + Returns: + dict: A dictionary mapping keys to their cached values + + `.mget` does not support None keys. This will filter out None keys. """ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget` _redis_client: Any = self.init_async_client() key_value_dict = {} start_time = time.time() + _key_list = [key for key in key_list if key is not None] try: _keys = [] - for cache_key in key_list: + for cache_key in _key_list: cache_key = self.check_and_fix_namespace(key=cache_key) _keys.append(cache_key) results = await _redis_client.mget(keys=_keys) - ## LOGGING ## end_time = time.time() _duration = end_time - start_time @@ -758,7 +783,7 @@ class RedisCache(BaseCache): # Associate the results back with their keys. # 'results' is a list of values corresponding to the order of keys in 'key_list'. - key_value_dict = dict(zip(key_list, results)) + key_value_dict = dict(zip(_key_list, results)) decoded_results = {} for k, v in key_value_dict.items(): @@ -783,7 +808,7 @@ class RedisCache(BaseCache): parent_otel_span=parent_otel_span, ) ) - print_verbose(f"Error occurred in pipeline read - {str(e)}") + verbose_logger.error(f"Error occurred in async batch get cache - {str(e)}") return key_value_dict def sync_ping(self) -> bool: diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index ae1e4d38c3..b384cead53 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -94,6 +94,45 @@ def test_dual_cache_batch_get_cache(): assert result[1] == None +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_batch_get_cache_with_none_keys(sync_mode): + """ + Unit testing for RedisCache batch_get_cache() and async_batch_get_cache() + - test with None keys. Ensure it can safely handle when keys are None. + - expect result = {key: None} + """ + from litellm.caching.caching import RedisCache + + litellm._turn_on_debug() + + redis_cache = RedisCache( + host=os.environ.get("REDIS_HOST"), + port=os.environ.get("REDIS_PORT"), + password=os.environ.get("REDIS_PASSWORD"), + ) + keys_to_lookup = [ + None, + f"test_value_{uuid.uuid4()}", + None, + f"test_value_2_{uuid.uuid4()}", + None, + f"test_value_3_{uuid.uuid4()}", + ] + if sync_mode: + result = redis_cache.batch_get_cache(key_list=keys_to_lookup) + print("result from batch_get_cache=", result) + else: + result = await redis_cache.async_batch_get_cache(key_list=keys_to_lookup) + print("result from async_batch_get_cache=", result) + expected_result = {} + for key in keys_to_lookup: + if key is None: + continue + expected_result[key] = None + assert result == expected_result + + # @pytest.mark.skip(reason="") def test_caching_dynamic_args(): # test in memory cache try: