From 01df37d8cfdde067c8de7749d4a2e0ae7a89e059 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 13 Jan 2024 11:50:50 +0530
Subject: [PATCH] fix(caching.py): use bulk writes and blocking connection
 pooling for reads from Redis

---
 litellm/_redis.py  | 14 ++++++-
 litellm/caching.py | 95 +++++++++++++++++++++++++++++++++++++++++++---
 litellm/utils.py   | 19 ++++------
 3 files changed, 109 insertions(+), 19 deletions(-)

diff --git a/litellm/_redis.py b/litellm/_redis.py
index 36f4ef870..4484926d4 100644
--- a/litellm/_redis.py
+++ b/litellm/_redis.py
@@ -106,4 +106,16 @@ def get_redis_async_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
         return async_redis.Redis.from_url(**redis_kwargs)
-    return async_redis.Redis(socket_timeout=5, **redis_kwargs)
+    return async_redis.Redis(
+        socket_timeout=5,
+        **redis_kwargs,
+    )
+
+
+def get_redis_connection_pool(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.BlockingConnectionPool.from_url(
+            timeout=5, url=redis_kwargs["url"]
+        )
+    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
diff --git a/litellm/caching.py b/litellm/caching.py
index b89220e8d..de3b02297 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -8,7 +8,7 @@
 #  Thank you users! We ❤️ you! - Krrish & Ishaan
 
 import litellm
-import time, logging
+import time, logging, asyncio
 import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any
 from openai._models import BaseModel as OpenAIObject
@@ -82,7 +82,7 @@ class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache
 
     def __init__(self, host=None, port=None, password=None, **kwargs):
-        from ._redis import get_redis_client
+        from ._redis import get_redis_client, get_redis_connection_pool
 
         redis_kwargs = {}
         if host is not None:
@@ -95,11 +95,20 @@ class RedisCache(BaseCache):
         redis_kwargs.update(kwargs)
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
+        self.async_redis_conn_pool = get_redis_connection_pool()
+        print_verbose(
+            f"Number of available connections init: {self.async_redis_conn_pool.pool.qsize()}"
+        )
 
     def init_async_client(self):
         from ._redis import get_redis_async_client
 
-        return get_redis_async_client(**self.redis_kwargs)
+        print_verbose(
+            f"Number of available connections client_init: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+        return get_redis_async_client(
+            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+        )
 
     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
@@ -111,16 +120,52 @@
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
 
     async def async_set_cache(self, key, value, **kwargs):
-        async with self.init_async_client() as redis_client:
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
             ttl = kwargs.get("ttl", None)
             print_verbose(
                 f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
             )
             try:
-                await redis_client.set(name=key, value=str(value), ex=ttl)
+                await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
             except Exception as e:
                 # NON blocking - notify users Redis is throwing an exception
                 logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+            print_verbose(
+                f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+            )
+
+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        """
+        Use Redis Pipelines for bulk write operations
+        """
+        _redis_client = self.init_async_client()
+        try:
+            async with _redis_client as redis_client:
+                async with redis_client.pipeline(transaction=True) as pipe:
+                    # Iterate through each key-value pair in the cache_list and set them in the pipeline.
+                    for cache_key, cache_value in cache_list:
+                        print_verbose(
+                            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
+                        )
+                        # Set the value with a TTL if it's provided.
+                        if ttl is not None:
+                            pipe.setex(cache_key, ttl, json.dumps(cache_value))
+                        else:
+                            pipe.set(cache_key, json.dumps(cache_value))
+                    # Execute the pipeline and return the results.
+                    results = await pipe.execute()
+            print_verbose(
+                f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+            )
+
+            print_verbose(f"pipeline results: {results}")
+            # Optionally, you could process 'results' to make sure that all set operations were successful.
+            return results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline write - {str(e)}")
+            # NON blocking - notify users Redis is throwing an exception
+            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
 
     def _get_cache_logic(self, cached_response: Any):
         """
@@ -152,7 +197,8 @@ class RedisCache(BaseCache):
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
 
     async def async_get_cache(self, key, **kwargs):
-        async with self.init_async_client() as redis_client:
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
             try:
                 print_verbose(f"Get Redis Cache: key: {key}")
                 cached_response = await redis_client.get(key)
@@ -166,6 +212,10 @@
                 traceback.print_exc()
                 logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
 
+        print_verbose(
+            f"Number of available connections get_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+
     def flush_cache(self):
         self.redis_client.flushall()
 
@@ -684,6 +734,39 @@ class Cache:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
 
+    async def async_add_cache_pipeline(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache for Embedding calls
+
+        Does a bulk write, to prevent using too many clients
+        """
+        try:
+            cache_list = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = litellm.cache.get_cache_key(
+                    *args, **{**kwargs, "input": i}
+                )
+                embedding_response = result.data[idx]
+                cache_key, cached_data = self._add_cache_logic(
+                    result=embedding_response,
+                    cache_key=preset_cache_key,
+                    *args,
+                    **kwargs,
+                )
+                cache_list.append((cache_key, cached_data))
+            if hasattr(self.cache, "async_set_cache_pipeline"):
+                await self.cache.async_set_cache_pipeline(cache_list=cache_list)
+            else:
+                tasks = []
+                for cache_key, cached_data in cache_list:
+                    tasks.append(
+                        self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+                    )
+                await asyncio.gather(*tasks)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Exception add_cache: {str(e)}")
+            traceback.print_exc()
+
     async def disconnect(self):
         if hasattr(self.cache, "disconnect"):
             await self.cache.disconnect()
diff --git a/litellm/utils.py b/litellm/utils.py
index 84a81649a..6059cdf85 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2346,6 +2346,9 @@ def client(original_function):
                         kwargs["input"] = remaining_list
 
                         if len(non_null_list) > 0:
+                            print_verbose(
+                                f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
+                            )
                             final_embedding_cached_response = EmbeddingResponse(
                                 model=kwargs.get("model"),
                                 data=[None] * len(original_kwargs_input),
@@ -2451,19 +2454,11 @@ def client(original_function):
                     if isinstance(result, EmbeddingResponse) and isinstance(
                         kwargs["input"], list
                     ):
-                        for idx, i in enumerate(kwargs["input"]):
-                            preset_cache_key = litellm.cache.get_cache_key(
-                                *args, **{**kwargs, "input": i}
+                        asyncio.create_task(
+                            litellm.cache.async_add_cache_pipeline(
+                                result, *args, **kwargs
                             )
-                            embedding_response = result.data[idx]
-                            asyncio.create_task(
-                                litellm.cache.async_add_cache(
-                                    embedding_response,
-                                    *args,
-                                    cache_key=preset_cache_key,
-                                )
-                            )
-                        # pass
+                        )
                     else:
                         asyncio.create_task(
                             litellm.cache.async_add_cache(
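
The _redis.py change above swaps ad-hoc client creation for a shared redis.asyncio.BlockingConnectionPool: when every connection is checked out, a caller waits up to `timeout` seconds for one to be returned instead of opening a new socket or failing immediately. Below is a minimal sketch of that behavior, assuming redis-py's asyncio API; the host, port, and key names are placeholders, not LiteLLM configuration.

import asyncio

import redis.asyncio as async_redis


async def main():
    # One pool per process: at most max_connections sockets, and callers
    # block for up to `timeout` seconds when the pool is exhausted.
    pool = async_redis.BlockingConnectionPool(
        host="localhost", port=6379, max_connections=10, timeout=5
    )
    # Clients built on the shared pool are cheap wrappers; each command
    # borrows a connection and returns it when the command finishes.
    client = async_redis.Redis(connection_pool=pool)
    await client.set("litellm:example", "cached-value", ex=60)
    print(await client.get("litellm:example"))
    await pool.disconnect()


asyncio.run(main())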
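
async_set_cache_pipeline buffers every SET/SETEX client-side and flushes them with a single execute(), so caching an embedding response with N inputs costs one borrowed connection and one round trip rather than N. Here is a sketch of the same pattern under the same redis-py assumption; the keys and vector values are illustrative only.

import asyncio
import json

import redis.asyncio as async_redis


async def bulk_write(cache_list, ttl=None):
    client = async_redis.Redis(host="localhost", port=6379)
    async with client.pipeline(transaction=True) as pipe:
        # Commands are buffered client-side; nothing hits the wire yet.
        for key, value in cache_list:
            if ttl is not None:
                pipe.setex(key, ttl, json.dumps(value))
            else:
                pipe.set(key, json.dumps(value))
        # One network round trip (and one connection) for the whole batch.
        results = await pipe.execute()
    return results


asyncio.run(bulk_write([("emb:1", [0.1, 0.2]), ("emb:2", [0.3, 0.4])], ttl=60))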
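
In utils.py, the bulk write is wrapped in asyncio.create_task, so the embedding response is returned to the caller while the Redis write completes in the background on the event loop. A toy sketch of that fire-and-forget shape follows; every name here is a stand-in rather than a LiteLLM API.

import asyncio


async def cache_write(key, value):
    await asyncio.sleep(0.05)  # stand-in for the Redis round trip
    print(f"cached {key}")


async def handle_request():
    result = {"embedding": [0.1, 0.2]}  # stand-in for the provider response
    # Schedule the write and return immediately. Keeping a reference
    # prevents the task from being garbage-collected mid-flight.
    task = asyncio.create_task(cache_write("emb:req-1", result))
    return result, task


async def main():
    result, task = await handle_request()
    print("response returned before the cache write finished")
    await task  # a long-lived server loop would drive this in the background


asyncio.run(main())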