fix(caching.py): use bulk writes and blocking connection pooling for reads from Redis

Krrish Dholakia 2024-01-13 11:50:50 +05:30
parent 007870390d
commit 01df37d8cf
3 changed files with 109 additions and 19 deletions
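Editor's note: the change swaps ad-hoc Redis clients for one shared redis.asyncio.BlockingConnectionPool. Unlike the default ConnectionPool, which raises ConnectionError once max_connections are all checked out, a blocking pool makes callers wait up to `timeout` seconds for a free connection. A minimal sketch of the pattern against redis-py's asyncio API (host, port, and pool sizes are illustrative, not from the commit):

import asyncio
import redis.asyncio as async_redis

async def main():
    # One pool for the whole process; acquiring a connection waits up to
    # `timeout` seconds when all connections are busy instead of erroring.
    pool = async_redis.BlockingConnectionPool(
        host="localhost", port=6379, max_connections=10, timeout=5
    )
    client = async_redis.Redis(connection_pool=pool)
    await client.set("greeting", "hello")
    print(await client.get("greeting"))
    await client.aclose()  # redis-py >= 5.0; older releases use close()
    await pool.disconnect()

asyncio.run(main())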

litellm/_redis.py

@@ -106,4 +106,16 @@ def get_redis_async_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
         return async_redis.Redis.from_url(**redis_kwargs)
-    return async_redis.Redis(socket_timeout=5, **redis_kwargs)
+    return async_redis.Redis(
+        socket_timeout=5,
+        **redis_kwargs,
+    )
+
+
+def get_redis_connection_pool(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.BlockingConnectionPool.from_url(
+            timeout=5, url=redis_kwargs["url"]
+        )
+    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
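Editor's note: the print_verbose calls added in caching.py below report pool occupancy via self.async_redis_conn_pool.pool.qsize(). In the redis-py asyncio implementation this commit targets, BlockingConnectionPool keeps its idle connections on an internal asyncio queue, so qsize() is the number of connections currently free; `.pool` is an implementation detail rather than public API. A hypothetical probe:

pool = async_redis.BlockingConnectionPool(max_connections=10, timeout=5)
# `.pool` is the internal queue of idle connections (implementation detail,
# not public API); qsize() reports how many are free right now.
print(pool.pool.qsize())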

litellm/caching.py

@@ -8,7 +8,7 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 
 import litellm
-import time, logging
+import time, logging, asyncio
 import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any
 from openai._models import BaseModel as OpenAIObject
@@ -82,7 +82,7 @@ class RedisCache(BaseCache):
     # if users don't provide one, use the default litellm cache
     def __init__(self, host=None, port=None, password=None, **kwargs):
-        from ._redis import get_redis_client
+        from ._redis import get_redis_client, get_redis_connection_pool
 
         redis_kwargs = {}
         if host is not None:
@@ -95,11 +95,20 @@ class RedisCache(BaseCache):
         redis_kwargs.update(kwargs)
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
+        self.async_redis_conn_pool = get_redis_connection_pool()
+        print_verbose(
+            f"Number of available connections init: {self.async_redis_conn_pool.pool.qsize()}"
+        )
 
     def init_async_client(self):
         from ._redis import get_redis_async_client
 
-        return get_redis_async_client(**self.redis_kwargs)
+        print_verbose(
+            f"Number of available connections client_init: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+        return get_redis_async_client(
+            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+        )
 
     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
@@ -111,16 +120,52 @@ class RedisCache(BaseCache):
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
async def async_set_cache(self, key, value, **kwargs):
async with self.init_async_client() as redis_client:
_redis_client = self.init_async_client()
async with _redis_client as redis_client:
ttl = kwargs.get("ttl", None)
print_verbose(
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
await redis_client.set(name=key, value=str(value), ex=ttl)
await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
print_verbose(
f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
)
async def async_set_cache_pipeline(self, cache_list, ttl=None):
"""
Use Redis Pipelines for bulk write operations
"""
_redis_client = self.init_async_client()
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
for cache_key, cache_value in cache_list:
print_verbose(
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
)
# Set the value with a TTL if it's provided.
if ttl is not None:
pipe.setex(cache_key, ttl, json.dumps(cache_value))
else:
pipe.set(cache_key, json.dumps(cache_value))
# Execute the pipeline and return the results.
results = await pipe.execute()
print_verbose(
f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
)
print_verbose(f"pipeline results: {results}")
# Optionally, you could process 'results' to make sure that all set operations were successful.
return results
except Exception as e:
print_verbose(f"Error occurred in pipeline write - {str(e)}")
# NON blocking - notify users Redis is throwing an exception
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
def _get_cache_logic(self, cached_response: Any):
"""
@@ -152,7 +197,8 @@ class RedisCache(BaseCache):
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
async def async_get_cache(self, key, **kwargs):
async with self.init_async_client() as redis_client:
_redis_client = self.init_async_client()
async with _redis_client as redis_client:
try:
print_verbose(f"Get Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
@@ -166,6 +212,10 @@ class RedisCache(BaseCache):
                 traceback.print_exc()
                 logging.debug(f"LiteLLM Caching: get() - Got exception from REDIS: {str(e)}")
+        print_verbose(
+            f"Number of available connections get_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+        )
 
     def flush_cache(self):
         self.redis_client.flushall()
@@ -684,6 +734,39 @@ class Cache:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
async def async_add_cache_pipeline(self, result, *args, **kwargs):
"""
Async implementation of add_cache for Embedding calls
Does a bulk write, to prevent using too many clients
"""
try:
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
)
embedding_response = result.data[idx]
cache_key, cached_data = self._add_cache_logic(
result=embedding_response,
cache_key=preset_cache_key,
*args,
**kwargs,
)
cache_list.append((cache_key, cached_data))
if hasattr(self.cache, "async_set_cache_pipeline"):
await self.cache.async_set_cache_pipeline(cache_list=cache_list)
else:
tasks = []
for val in cache_list:
tasks.append(
self.cache.async_set_cache(cache_key, cached_data, **kwargs)
)
await asyncio.gather(*tasks)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
async def disconnect(self):
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
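Editor's note: the bulk write in async_set_cache_pipeline relies on Redis pipelining, where commands are buffered client-side and flushed to the server in one round trip; transaction=True additionally wraps them in MULTI/EXEC. A self-contained sketch of the same pattern (assumes a local Redis; real code reuses the shared pool):

import asyncio, json
import redis.asyncio as async_redis

async def bulk_set(items, ttl=None):
    client = async_redis.Redis()  # illustrative; not the commit's pooled client
    async with client.pipeline(transaction=True) as pipe:
        for key, value in items:
            # Queue SETEX/SET client-side; nothing is sent yet.
            if ttl is not None:
                pipe.setex(key, ttl, json.dumps(value))
            else:
                pipe.set(key, json.dumps(value))
        results = await pipe.execute()  # one round trip for all queued commands
    await client.aclose()
    return results  # one status per queued SET

print(asyncio.run(bulk_set([("k1", {"a": 1}), ("k2", {"b": 2})], ttl=60)))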

litellm/utils.py

@@ -2346,6 +2346,9 @@ def client(original_function):
kwargs["input"] = remaining_list
if len(non_null_list) > 0:
print_verbose(
f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
)
final_embedding_cached_response = EmbeddingResponse(
model=kwargs.get("model"),
data=[None] * len(original_kwargs_input),
@@ -2451,19 +2454,11 @@ def client(original_function):
                     if isinstance(result, EmbeddingResponse) and isinstance(
                         kwargs["input"], list
                     ):
-                        for idx, i in enumerate(kwargs["input"]):
-                            preset_cache_key = litellm.cache.get_cache_key(
-                                *args, **{**kwargs, "input": i}
-                            )
-                            embedding_response = result.data[idx]
-                            asyncio.create_task(
-                                litellm.cache.async_add_cache(
-                                    embedding_response,
-                                    *args,
-                                    cache_key=preset_cache_key,
-                                )
-                            )
-                            # pass
+                        asyncio.create_task(
+                            litellm.cache.async_add_cache_pipeline(
+                                result, *args, **kwargs
+                            )
+                        )
                     else:
                         asyncio.create_task(
                             litellm.cache.async_add_cache(