fix(caching.py): use bulk writes and BlockingConnectionPool for reads from Redis

Krrish Dholakia 2024-01-13 11:50:50 +05:30
parent 007870390d
commit 01df37d8cf
3 changed files with 109 additions and 19 deletions
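
For context, a minimal sketch (not part of this commit) of the two redis-py building blocks the title refers to, assuming a Redis server on localhost:6379: a BlockingConnectionPool caps the number of open connections and makes extra callers wait for a free one, and a pipeline batches several writes into a single round trip.

# Context sketch only: the redis.asyncio pool and pipeline primitives the change relies on.
import asyncio
import redis.asyncio as async_redis


async def main():
    # BlockingConnectionPool opens at most `max_connections` sockets; extra callers
    # wait up to `timeout` seconds for a free connection instead of failing.
    pool = async_redis.BlockingConnectionPool(
        host="localhost", port=6379, max_connections=10, timeout=5
    )
    async with async_redis.Redis(connection_pool=pool) as client:
        # Pipeline: buffer several writes locally and send them in one round trip.
        async with client.pipeline(transaction=True) as pipe:
            pipe.set("a", "1")
            pipe.setex("b", 60, "2")
            results = await pipe.execute()
        print(results)  # e.g. [True, True]
    await pool.disconnect()


asyncio.run(main())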

litellm/_redis.py

@@ -106,4 +106,16 @@ def get_redis_async_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
         return async_redis.Redis.from_url(**redis_kwargs)
-    return async_redis.Redis(socket_timeout=5, **redis_kwargs)
+    return async_redis.Redis(
+        socket_timeout=5,
+        **redis_kwargs,
+    )
+
+
+def get_redis_connection_pool(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.BlockingConnectionPool.from_url(
+            timeout=5, url=redis_kwargs["url"]
+        )
+    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
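
A small usage sketch of the new get_redis_connection_pool helper, illustrative only: the host/port overrides and keys are placeholders, and each short-lived client borrows a connection from the shared pool rather than opening its own socket.

# Illustrative only: one pool created up front, reused by each short-lived client.
import asyncio
import redis.asyncio as async_redis
from litellm._redis import get_redis_connection_pool


async def main():
    pool = get_redis_connection_pool(host="localhost", port=6379)
    for i in range(3):
        # Each client borrows a pooled connection and returns it on close.
        async with async_redis.Redis(connection_pool=pool) as client:
            await client.set(f"key-{i}", f"value-{i}", ex=60)
    await pool.disconnect()


asyncio.run(main())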

litellm/caching.py

@@ -8,7 +8,7 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 import litellm
-import time, logging
+import time, logging, asyncio
 import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any
 from openai._models import BaseModel as OpenAIObject
@@ -82,7 +82,7 @@ class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache
     def __init__(self, host=None, port=None, password=None, **kwargs):
-        from ._redis import get_redis_client
+        from ._redis import get_redis_client, get_redis_connection_pool

         redis_kwargs = {}
         if host is not None:
@@ -95,11 +95,20 @@ class RedisCache(BaseCache):
         redis_kwargs.update(kwargs)
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
+        self.async_redis_conn_pool = get_redis_connection_pool()
+        print_verbose(
+            f"Number of available connections init: {self.async_redis_conn_pool.pool.qsize()}"
+        )

     def init_async_client(self):
         from ._redis import get_redis_async_client

-        return get_redis_async_client(**self.redis_kwargs)
+        print_verbose(
+            f"Number of available connections client_init: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+        return get_redis_async_client(
+            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+        )

     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
@@ -111,16 +120,52 @@ class RedisCache(BaseCache):
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

     async def async_set_cache(self, key, value, **kwargs):
-        async with self.init_async_client() as redis_client:
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
             ttl = kwargs.get("ttl", None)
             print_verbose(
                 f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
             )
             try:
-                await redis_client.set(name=key, value=str(value), ex=ttl)
+                await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
             except Exception as e:
                 # NON blocking - notify users Redis is throwing an exception
                 logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+        print_verbose(
+            f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+
+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        """
+        Use Redis Pipelines for bulk write operations
+        """
+        _redis_client = self.init_async_client()
+        try:
+            async with _redis_client as redis_client:
+                async with redis_client.pipeline(transaction=True) as pipe:
+                    # Iterate through each key-value pair in the cache_list and set them in the pipeline.
+                    for cache_key, cache_value in cache_list:
+                        print_verbose(
+                            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
+                        )
+                        # Set the value with a TTL if it's provided.
+                        if ttl is not None:
+                            pipe.setex(cache_key, ttl, json.dumps(cache_value))
+                        else:
+                            pipe.set(cache_key, json.dumps(cache_value))
+                    # Execute the pipeline and return the results.
+                    results = await pipe.execute()
+
+            print_verbose(
+                f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+            )
+            print_verbose(f"pipeline results: {results}")
+            # Optionally, you could process 'results' to make sure that all set operations were successful.
+            return results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline write - {str(e)}")
+            # NON blocking - notify users Redis is throwing an exception
+            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

     def _get_cache_logic(self, cached_response: Any):
         """
@@ -152,7 +197,8 @@ class RedisCache(BaseCache):
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

     async def async_get_cache(self, key, **kwargs):
-        async with self.init_async_client() as redis_client:
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
             try:
                 print_verbose(f"Get Redis Cache: key: {key}")
                 cached_response = await redis_client.get(key)
@@ -166,6 +212,10 @@ class RedisCache(BaseCache):
                 traceback.print_exc()
                 logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
+        print_verbose(
+            f"Number of available connections get_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+        )

     def flush_cache(self):
         self.redis_client.flushall()
@@ -684,6 +734,39 @@ class Cache:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()

+    async def async_add_cache_pipeline(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache for Embedding calls
+
+        Does a bulk write, to prevent using too many clients
+        """
+        try:
+            cache_list = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = litellm.cache.get_cache_key(
+                    *args, **{**kwargs, "input": i}
+                )
+                embedding_response = result.data[idx]
+                cache_key, cached_data = self._add_cache_logic(
+                    result=embedding_response,
+                    cache_key=preset_cache_key,
+                    *args,
+                    **kwargs,
+                )
+                cache_list.append((cache_key, cached_data))
+            if hasattr(self.cache, "async_set_cache_pipeline"):
+                await self.cache.async_set_cache_pipeline(cache_list=cache_list)
+            else:
+                tasks = []
+                for val in cache_list:
+                    tasks.append(
+                        self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+                    )
+                await asyncio.gather(*tasks)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()
+
     async def disconnect(self):
         if hasattr(self.cache, "disconnect"):
             await self.cache.disconnect()
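
For illustration, a self-contained sketch (not from the commit) of the fallback branch above: when the backing cache has no async_set_cache_pipeline, the individual writes are still issued concurrently via asyncio.gather rather than awaited one by one. The fake_set coroutine stands in for async_set_cache.

# Illustrative sketch of the non-pipeline fallback: fire one async write per key
# and wait for all of them together instead of awaiting each write in turn.
import asyncio


async def fake_set(key: str, value: str) -> None:  # placeholder for async_set_cache
    await asyncio.sleep(0.01)
    print(f"wrote {key}={value}")


async def main():
    cache_list = [("k1", "v1"), ("k2", "v2"), ("k3", "v3")]
    await asyncio.gather(*(fake_set(k, v) for k, v in cache_list))


asyncio.run(main())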

litellm/utils.py

@@ -2346,6 +2346,9 @@ def client(original_function):
                        kwargs["input"] = remaining_list
                        if len(non_null_list) > 0:
+                            print_verbose(
+                                f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
+                            )
                            final_embedding_cached_response = EmbeddingResponse(
                                model=kwargs.get("model"),
                                data=[None] * len(original_kwargs_input),
@@ -2451,19 +2454,11 @@ def client(original_function):
                    if isinstance(result, EmbeddingResponse) and isinstance(
                        kwargs["input"], list
                    ):
-                        for idx, i in enumerate(kwargs["input"]):
-                            preset_cache_key = litellm.cache.get_cache_key(
-                                *args, **{**kwargs, "input": i}
-                            )
-                            embedding_response = result.data[idx]
-                            asyncio.create_task(
-                                litellm.cache.async_add_cache(
-                                    embedding_response,
-                                    *args,
-                                    cache_key=preset_cache_key,
-                                )
-                            )
-                        # pass
+                        asyncio.create_task(
+                            litellm.cache.async_add_cache_pipeline(
+                                result, *args, **kwargs
+                            )
+                        )
                    else:
                        asyncio.create_task(
                            litellm.cache.async_add_cache(
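
Finally, a hedged end-to-end sketch of the call path that now takes the pipeline branch: an async embedding request with the Redis cache enabled, so each input in the batch is written back through async_add_cache_pipeline. The model name, inputs, connection details, and the explicit caching=True flag are placeholders/assumptions, and an OpenAI API key is assumed to be configured.

# Illustrative only: exercises the new bulk-write path for embedding caching.
import asyncio
import litellm
from litellm.caching import Cache


async def main():
    # Redis-backed cache; connection details are placeholders.
    litellm.cache = Cache(type="redis", host="localhost", port=6379)

    # A batched embedding call: each input is cached individually, and with this
    # commit the write-backs go through one Redis pipeline instead of one task per item.
    response = await litellm.aembedding(
        model="text-embedding-ada-002",
        input=["first document", "second document"],
        caching=True,
    )
    print(len(response.data))


asyncio.run(main())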