(Redis fix) - use mget_non_atomic (#8682)

* use mget_nonatomic

* redis cluster override MGET op

* fix redis cluster + MGET

* test redis cluster
This commit is contained in:
Ishaan Jaff 2025-02-20 17:51:31 -08:00 committed by GitHub
parent bb6f43d12e
commit ccfbb77b73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 106 additions and 8 deletions

View file

@ -637,6 +637,23 @@ class RedisCache(BaseCache):
"litellm.caching.caching: get() - Got exception from REDIS: ", e
)
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Wrapper to call `mget` on the redis client
We use a wrapper so RedisCluster can override this method
"""
return self.redis_client.mget(keys=keys) # type: ignore
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Wrapper to call `mget` on the redis client
We use a wrapper so RedisCluster can override this method
"""
async_redis_client = self.init_async_client()
return await async_redis_client.mget(keys=keys) # type: ignore
def batch_get_cache(
self,
key_list: Union[List[str], List[Optional[str]]],
@ -661,7 +678,7 @@ class RedisCache(BaseCache):
cache_key = self.check_and_fix_namespace(key=cache_key or "")
_keys.append(cache_key)
start_time = time.time()
results: List = self.redis_client.mget(keys=_keys) # type: ignore
results: List = self._run_redis_mget_operation(keys=_keys)
end_time = time.time()
_duration = end_time - start_time
self.service_logger_obj.service_success_hook(
@ -757,7 +774,6 @@ class RedisCache(BaseCache):
`.mget` does not support None keys. This will filter out None keys.
"""
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
_redis_client: Any = self.init_async_client()
key_value_dict = {}
start_time = time.time()
_key_list = [key for key in key_list if key is not None]
@ -766,7 +782,7 @@ class RedisCache(BaseCache):
for cache_key in _key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = await _redis_client.mget(keys=_keys)
results = await self._async_run_redis_mget_operation(keys=_keys)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time