diff --git a/litellm/caching.py b/litellm/caching.py
index 5a90083428..79c030ad53 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -38,6 +38,10 @@ class BaseCache:
     async def async_get_cache(self, key, **kwargs):
         raise NotImplementedError
 
+    async def batch_cache_write(self, result, *args, **kwargs):
+        """Queue a write to be flushed to the backend in a batch."""
+        raise NotImplementedError
+
     async def disconnect(self):
         raise NotImplementedError
 
@@ -96,7 +100,9 @@ class InMemoryCache(BaseCache):
 class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache
 
-    def __init__(self, host=None, port=None, password=None, **kwargs):
+    def __init__(
+        self, host=None, port=None, password=None, redis_flush_size=100, **kwargs
+    ):
         from ._redis import get_redis_client, get_redis_connection_pool
 
         redis_kwargs = {}
@@ -111,6 +117,10 @@ class RedisCache(BaseCache):
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
         self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
+
+        # for high traffic, we store the redis results in memory and then batch write to redis
+        self.redis_batch_writing_buffer = []
+        self.redis_flush_size = redis_flush_size
         self.redis_version = "Unknown"
         try:
             self.redis_version = self.redis_client.info()["redis_version"]
@@ -193,6 +203,34 @@ class RedisCache(BaseCache):
         except Exception as e:
             print_verbose(f"Error occurred in pipeline write - {str(e)}")
 
+    async def batch_cache_write(self, key, value, **kwargs):
+        """
+        Buffer a (key, value) pair in memory instead of writing it immediately.
+
+        Once the buffer holds `redis_flush_size` entries it is flushed to redis.
+        """
+        print_verbose("in batch cache writing for redis")
+
+        self.redis_batch_writing_buffer.append((key, value))
+        if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
+            await self.flush_cache_buffer()
+
+    async def flush_cache_buffer(self):
+        """
+        Write every buffered entry to redis with a single pipeline call.
+        """
+        # Detach the buffer before awaiting: entries appended while the
+        # pipeline write is in flight go to the fresh list, rather than being
+        # dropped by a reset that runs after the await.
+        pending = self.redis_batch_writing_buffer
+        self.redis_batch_writing_buffer = []
+        if not pending:
+            return
+
+        # One pre-formatted string, like every other print_verbose call here.
+        print_verbose(f"flushing to redis....reached size of buffer {len(pending)}")
+        await self.async_set_cache_pipeline(pending)
+
     def _get_cache_logic(self, cached_response: Any):
         """
         Common 'get_cache_logic' across sync + async redis client implementations
@@ -908,6 +946,7 @@ class Cache:
         s3_path: Optional[str] = None,
         redis_semantic_cache_use_async=False,
         redis_semantic_cache_embedding_model="text-embedding-ada-002",
+        redis_flush_size=100,
         **kwargs,
     ):
         """
@@ -930,7 +969,9 @@ class Cache:
             None. Cache is set as a litellm param
         """
         if type == "redis":
-            self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
+            self.cache: BaseCache = RedisCache(
+                host, port, password, redis_flush_size, **kwargs
+            )
         elif type == "redis-semantic":
             self.cache = RedisSemanticCache(
                 host,
@@ -1287,6 +1328,15 @@ class Cache:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
 
+    async def batch_cache_write(self, result, *args, **kwargs):
+        """
+        Compute the cache key/value for `result` and queue it on the backend.
+        """
+        cache_key, cached_data, kwargs = self._add_cache_logic(
+            result=result, *args, **kwargs
+        )
+        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
+
     async def ping(self):
         if hasattr(self.cache, "ping"):
             return await self.cache.ping()