mirror of https://github.com/BerriAI/litellm.git
(feat) v0 batch redis cache writes

commit 37aadba959 (parent e5be002adc)
1 changed file with 35 additions and 2 deletions
@@ -38,6 +38,9 @@ class BaseCache:
     async def async_get_cache(self, key, **kwargs):
         raise NotImplementedError
 
+    async def batch_cache_write(self, result, *args, **kwargs):
+        raise NotImplementedError
+
     async def disconnect(self):
         raise NotImplementedError
 
@@ -96,7 +99,9 @@ class InMemoryCache(BaseCache):
 class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache
 
-    def __init__(self, host=None, port=None, password=None, **kwargs):
+    def __init__(
+        self, host=None, port=None, password=None, redis_flush_size=100, **kwargs
+    ):
         from ._redis import get_redis_client, get_redis_connection_pool
 
         redis_kwargs = {}
@@ -111,6 +116,10 @@ class RedisCache(BaseCache):
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
         self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
+
+        # for high traffic, we store the redis results in memory and then batch write to redis
+        self.redis_batch_writing_buffer = []
+        self.redis_flush_size = redis_flush_size
         self.redis_version = "Unknown"
         try:
             self.redis_version = self.redis_client.info()["redis_version"]
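Note: the two fields added above are the entire batching state. redis_batch_writing_buffer holds (key, value) pairs in process memory, and redis_flush_size is the threshold at which they are written out in one pipelined call. A minimal construction sketch with a custom threshold, assuming RedisCache is importable from litellm.caching and a Redis server is reachable at the placeholder host/port:

    from litellm.caching import RedisCache

    # hold up to 50 writes in memory before one pipelined flush to redis;
    # host/port/password are placeholders, not values from this commit
    cache = RedisCache(host="localhost", port=6379, password=None, redis_flush_size=50)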
@@ -193,6 +202,21 @@ class RedisCache(BaseCache):
         except Exception as e:
             print_verbose(f"Error occurred in pipeline write - {str(e)}")
 
+    async def batch_cache_write(self, key, value, **kwargs):
+        print_verbose("in batch cache writing for redis")
+
+        self.redis_batch_writing_buffer.append((key, value))
+        if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
+            await self.flush_cache_buffer()
+
+    async def flush_cache_buffer(self):
+        print_verbose(
+            "flushing to redis....reached size of buffer",
+            len(self.redis_batch_writing_buffer),
+        )
+        await self.async_set_cache_pipeline(self.redis_batch_writing_buffer)
+        self.redis_batch_writing_buffer = []
+
     def _get_cache_logic(self, cached_response: Any):
         """
         Common 'get_cache_logic' across sync + async redis client implementations
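The two methods above implement write coalescing: each batch_cache_write appends to the in-memory buffer, and once the buffer reaches redis_flush_size, flush_cache_buffer pushes the whole batch through async_set_cache_pipeline in a single round trip and resets the buffer. A self-contained sketch of the same pattern with the Redis pipeline call stubbed out, so it runs without a server (names mirror the diff, but this is an illustration, not the litellm implementation):

    import asyncio

    class BufferedCacheWriter:
        def __init__(self, flush_size=3):
            self.buffer = []              # mirrors redis_batch_writing_buffer
            self.flush_size = flush_size  # mirrors redis_flush_size

        async def batch_cache_write(self, key, value):
            self.buffer.append((key, value))
            if len(self.buffer) >= self.flush_size:
                await self.flush_cache_buffer()

        async def flush_cache_buffer(self):
            # stand-in for async_set_cache_pipeline: N buffered writes, one round trip
            print(f"flushing {len(self.buffer)} writes in one pipelined call")
            self.buffer = []

    async def main():
        writer = BufferedCacheWriter(flush_size=3)
        for i in range(7):
            await writer.batch_cache_write(f"key-{i}", i)
        # keys 0-2 and 3-5 were flushed; key-6 is still buffered

    asyncio.run(main())

One consequence of this design: entries that never reach the threshold sit in the buffer until the next flush, so a process crash can lose up to redis_flush_size - 1 pending writes.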
@@ -908,6 +932,7 @@ class Cache:
         s3_path: Optional[str] = None,
         redis_semantic_cache_use_async=False,
         redis_semantic_cache_embedding_model="text-embedding-ada-002",
+        redis_flush_size=100,
         **kwargs,
     ):
         """
@@ -930,7 +955,9 @@ class Cache:
             None. Cache is set as a litellm param
         """
         if type == "redis":
-            self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
+            self.cache: BaseCache = RedisCache(
+                host, port, password, redis_flush_size, **kwargs
+            )
         elif type == "redis-semantic":
             self.cache = RedisSemanticCache(
                 host,
@@ -1287,6 +1314,12 @@ class Cache:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
 
+    async def batch_cache_write(self, result, *args, **kwargs):
+        cache_key, cached_data, kwargs = self._add_cache_logic(
+            result=result, *args, **kwargs
+        )
+        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
+
     async def ping(self):
         if hasattr(self.cache, "ping"):
             return await self.cache.ping()
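End to end, the new threshold is threaded from Cache through to RedisCache. A hedged usage sketch, assuming a reachable Redis at the placeholder host/port; note that this commit only adds Cache.batch_cache_write, it does not yet change any call sites to invoke it:

    import litellm
    from litellm.caching import Cache

    # per the docstring above, "Cache is set as a litellm param"
    litellm.cache = Cache(
        type="redis",
        host="localhost",      # placeholder
        port=6379,             # placeholder
        password=None,
        redis_flush_size=100,  # buffering threshold added by this commit (default 100)
    )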