fix(caching.py): respect redis namespace for all redis get/set requests

This commit is contained in:
Krrish Dholakia 2024-03-30 20:20:29 -07:00
parent 7738107d49
commit b2e7866ea9

View file

@ -100,7 +100,13 @@ class RedisCache(BaseCache):
# if users don't provide one, use the default litellm cache # if users don't provide one, use the default litellm cache
def __init__( def __init__(
self, host=None, port=None, password=None, redis_flush_size=100, **kwargs self,
host=None,
port=None,
password=None,
redis_flush_size=100,
namespace: Optional[str] = None,
**kwargs,
): ):
from ._redis import get_redis_client, get_redis_connection_pool from ._redis import get_redis_client, get_redis_connection_pool
@ -116,9 +122,10 @@ class RedisCache(BaseCache):
self.redis_client = get_redis_client(**redis_kwargs) self.redis_client = get_redis_client(**redis_kwargs)
self.redis_kwargs = redis_kwargs self.redis_kwargs = redis_kwargs
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs) self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
# redis namespaces
self.namespace = namespace
# for high traffic, we store the redis results in memory and then batch write to redis # for high traffic, we store the redis results in memory and then batch write to redis
self.redis_batch_writing_buffer = [] self.redis_batch_writing_buffer: list = []
self.redis_flush_size = redis_flush_size self.redis_flush_size = redis_flush_size
self.redis_version = "Unknown" self.redis_version = "Unknown"
try: try:
@ -133,11 +140,21 @@ class RedisCache(BaseCache):
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
) )
def check_and_fix_namespace(self, key: str) -> str:
    """
    Make sure each key is stored under the configured redis namespace.

    Args:
        key: the raw cache key.

    Returns:
        ``key`` unchanged when no namespace is configured or the key is
        already namespaced, otherwise ``"<namespace>:<key>"``.
    """
    if self.namespace is not None:
        # Check against "namespace:" (not the bare namespace) so a key that
        # merely shares a prefix with the namespace (e.g. namespace="ns",
        # key="nsfoo") still gets properly namespaced.
        prefix = self.namespace + ":"
        if not key.startswith(prefix):
            key = prefix + key
    return key
def set_cache(self, key, value, **kwargs): def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None) ttl = kwargs.get("ttl", None)
print_verbose( print_verbose(
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}" f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
) )
key = self.check_and_fix_namespace(key=key)
try: try:
self.redis_client.set(name=key, value=str(value), ex=ttl) self.redis_client.set(name=key, value=str(value), ex=ttl)
except Exception as e: except Exception as e:
@ -158,6 +175,7 @@ class RedisCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs): async def async_set_cache(self, key, value, **kwargs):
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client: async with _redis_client as redis_client:
ttl = kwargs.get("ttl", None) ttl = kwargs.get("ttl", None)
print_verbose( print_verbose(
@ -187,6 +205,7 @@ class RedisCache(BaseCache):
async with redis_client.pipeline(transaction=True) as pipe: async with redis_client.pipeline(transaction=True) as pipe:
# Iterate through each key-value pair in the cache_list and set them in the pipeline. # Iterate through each key-value pair in the cache_list and set them in the pipeline.
for cache_key, cache_value in cache_list: for cache_key, cache_value in cache_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
print_verbose( print_verbose(
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
) )
@ -213,6 +232,7 @@ class RedisCache(BaseCache):
print_verbose( print_verbose(
f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}", f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
) )
key = self.check_and_fix_namespace(key=key)
self.redis_batch_writing_buffer.append((key, value)) self.redis_batch_writing_buffer.append((key, value))
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer() await self.flush_cache_buffer()
@ -242,6 +262,7 @@ class RedisCache(BaseCache):
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
try: try:
key = self.check_and_fix_namespace(key=key)
print_verbose(f"Get Redis Cache: key: {key}") print_verbose(f"Get Redis Cache: key: {key}")
cached_response = self.redis_client.get(key) cached_response = self.redis_client.get(key)
print_verbose( print_verbose(
@ -255,6 +276,7 @@ class RedisCache(BaseCache):
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client: async with _redis_client as redis_client:
try: try:
print_verbose(f"Get Async Redis Cache: key: {key}") print_verbose(f"Get Async Redis Cache: key: {key}")
@ -281,6 +303,7 @@ class RedisCache(BaseCache):
async with redis_client.pipeline(transaction=True) as pipe: async with redis_client.pipeline(transaction=True) as pipe:
# Queue the get operations in the pipeline for all keys. # Queue the get operations in the pipeline for all keys.
for cache_key in key_list: for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
pipe.get(cache_key) # Queue GET command in pipeline pipe.get(cache_key) # Queue GET command in pipeline
# Execute the pipeline and await the results. # Execute the pipeline and await the results.
@ -1015,6 +1038,9 @@ class Cache:
self.redis_flush_size = redis_flush_size self.redis_flush_size = redis_flush_size
self.ttl = ttl self.ttl = ttl
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
def get_cache_key(self, *args, **kwargs): def get_cache_key(self, *args, **kwargs):
""" """
Get the cache key for the given arguments. Get the cache key for the given arguments.