From b2e7866ea94a4beb78d92b683e5c72a8dd32959a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:20:29 -0700 Subject: [PATCH] fix(caching.py): respect redis namespace for all redis get/set requests --- litellm/caching.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index eea687edfe..ded347c10d 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -100,7 +100,13 @@ class RedisCache(BaseCache): # if users don't provider one, use the default litellm cache def __init__( - self, host=None, port=None, password=None, redis_flush_size=100, **kwargs + self, + host=None, + port=None, + password=None, + redis_flush_size=100, + namespace: Optional[str] = None, + **kwargs, ): from ._redis import get_redis_client, get_redis_connection_pool @@ -116,9 +122,10 @@ class RedisCache(BaseCache): self.redis_client = get_redis_client(**redis_kwargs) self.redis_kwargs = redis_kwargs self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs) - + # redis namespaces + self.namespace = namespace # for high traffic, we store the redis results in memory and then batch write to redis - self.redis_batch_writing_buffer = [] + self.redis_batch_writing_buffer: list = [] self.redis_flush_size = redis_flush_size self.redis_version = "Unknown" try: @@ -133,11 +140,21 @@ class RedisCache(BaseCache): connection_pool=self.async_redis_conn_pool, **self.redis_kwargs ) + def check_and_fix_namespace(self, key: str) -> str: + """ + Make sure each key starts with the given namespace + """ + if self.namespace is not None and not key.startswith(self.namespace): + key = self.namespace + ":" + key + + return key + def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) print_verbose( f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}" ) + key = self.check_and_fix_namespace(key=key) try: self.redis_client.set(name=key, value=str(value), ex=ttl) except Exception as e: @@ -158,6 +175,7 @@ class RedisCache(BaseCache): async def async_set_cache(self, key, value, **kwargs): _redis_client = self.init_async_client() + key = self.check_and_fix_namespace(key=key) async with _redis_client as redis_client: ttl = kwargs.get("ttl", None) print_verbose( @@ -187,6 +205,7 @@ class RedisCache(BaseCache): async with redis_client.pipeline(transaction=True) as pipe: # Iterate through each key-value pair in the cache_list and set them in the pipeline. for cache_key, cache_value in cache_list: + cache_key = self.check_and_fix_namespace(key=cache_key) print_verbose( f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" ) @@ -213,6 +232,7 @@ class RedisCache(BaseCache): print_verbose( f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}", ) + key = self.check_and_fix_namespace(key=key) self.redis_batch_writing_buffer.append((key, value)) if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: await self.flush_cache_buffer() @@ -242,6 +262,7 @@ class RedisCache(BaseCache): def get_cache(self, key, **kwargs): try: + key = self.check_and_fix_namespace(key=key) print_verbose(f"Get Redis Cache: key: {key}") cached_response = self.redis_client.get(key) print_verbose( @@ -255,6 +276,7 @@ class RedisCache(BaseCache): async def async_get_cache(self, key, **kwargs): _redis_client = self.init_async_client() + key = self.check_and_fix_namespace(key=key) async with _redis_client as redis_client: try: print_verbose(f"Get Async Redis Cache: key: {key}") @@ -281,6 +303,7 @@ class RedisCache(BaseCache): async with redis_client.pipeline(transaction=True) as pipe: # Queue the get operations in the pipeline for all keys. for cache_key in key_list: + cache_key = self.check_and_fix_namespace(key=cache_key) pipe.get(cache_key) # Queue GET command in pipeline # Execute the pipeline and await the results. @@ -1015,6 +1038,9 @@ class Cache: self.redis_flush_size = redis_flush_size self.ttl = ttl + if self.namespace is not None and isinstance(self.cache, RedisCache): + self.cache.namespace = self.namespace + def get_cache_key(self, *args, **kwargs): """ Get the cache key for the given arguments.