(feat) v0 batch redis cache writes

This commit is contained in:
Ishaan Jaff 2024-03-25 15:20:10 -07:00
parent e5be002adc
commit 37aadba959

View file

@ -38,6 +38,9 @@ class BaseCache:
async def async_get_cache(self, key, **kwargs):
raise NotImplementedError
async def batch_cache_write(self, result, *args, **kwargs):
raise NotImplementedError
async def disconnect(self):
raise NotImplementedError
@ -96,7 +99,9 @@ class InMemoryCache(BaseCache):
class RedisCache(BaseCache):
# if users don't provider one, use the default litellm cache
def __init__(self, host=None, port=None, password=None, **kwargs):
def __init__(
self, host=None, port=None, password=None, redis_flush_size=100, **kwargs
):
from ._redis import get_redis_client, get_redis_connection_pool
redis_kwargs = {}
@ -111,6 +116,10 @@ class RedisCache(BaseCache):
self.redis_client = get_redis_client(**redis_kwargs)
self.redis_kwargs = redis_kwargs
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
# for high traffic, we store the redis results in memory and then batch write to redis
self.redis_batch_writing_buffer = []
self.redis_flush_size = redis_flush_size
self.redis_version = "Unknown"
try:
self.redis_version = self.redis_client.info()["redis_version"]
@ -193,6 +202,21 @@ class RedisCache(BaseCache):
except Exception as e:
print_verbose(f"Error occurred in pipeline write - {str(e)}")
async def batch_cache_write(self, key, value, **kwargs):
print_verbose("in batch cache writing for redis")
self.redis_batch_writing_buffer.append((key, value))
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer()
async def flush_cache_buffer(self):
print_verbose(
"flushing to redis....reached size of buffer",
len(self.redis_batch_writing_buffer),
)
await self.async_set_cache_pipeline(self.redis_batch_writing_buffer)
self.redis_batch_writing_buffer = []
def _get_cache_logic(self, cached_response: Any):
"""
Common 'get_cache_logic' across sync + async redis client implementations
@ -908,6 +932,7 @@ class Cache:
s3_path: Optional[str] = None,
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=100,
**kwargs,
):
"""
@ -930,7 +955,9 @@ class Cache:
None. Cache is set as a litellm param
"""
if type == "redis":
self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
self.cache: BaseCache = RedisCache(
host, port, password, redis_flush_size, **kwargs
)
elif type == "redis-semantic":
self.cache = RedisSemanticCache(
host,
@ -1287,6 +1314,12 @@ class Cache:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
async def batch_cache_write(self, result, *args, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
)
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):
if hasattr(self.cache, "ping"):
return await self.cache.ping()