mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(caching.py): respect redis namespace for all redis get/set requests
This commit is contained in:
parent
7738107d49
commit
b2e7866ea9
1 changed files with 29 additions and 3 deletions
|
@ -100,7 +100,13 @@ class RedisCache(BaseCache):
|
||||||
# if users don't provider one, use the default litellm cache
|
# if users don't provider one, use the default litellm cache
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, host=None, port=None, password=None, redis_flush_size=100, **kwargs
|
self,
|
||||||
|
host=None,
|
||||||
|
port=None,
|
||||||
|
password=None,
|
||||||
|
redis_flush_size=100,
|
||||||
|
namespace: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
from ._redis import get_redis_client, get_redis_connection_pool
|
from ._redis import get_redis_client, get_redis_connection_pool
|
||||||
|
|
||||||
|
@ -116,9 +122,10 @@ class RedisCache(BaseCache):
|
||||||
self.redis_client = get_redis_client(**redis_kwargs)
|
self.redis_client = get_redis_client(**redis_kwargs)
|
||||||
self.redis_kwargs = redis_kwargs
|
self.redis_kwargs = redis_kwargs
|
||||||
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
|
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
|
||||||
|
# redis namespaces
|
||||||
|
self.namespace = namespace
|
||||||
# for high traffic, we store the redis results in memory and then batch write to redis
|
# for high traffic, we store the redis results in memory and then batch write to redis
|
||||||
self.redis_batch_writing_buffer = []
|
self.redis_batch_writing_buffer: list = []
|
||||||
self.redis_flush_size = redis_flush_size
|
self.redis_flush_size = redis_flush_size
|
||||||
self.redis_version = "Unknown"
|
self.redis_version = "Unknown"
|
||||||
try:
|
try:
|
||||||
|
@ -133,11 +140,21 @@ class RedisCache(BaseCache):
|
||||||
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
|
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_and_fix_namespace(self, key: str) -> str:
|
||||||
|
"""
|
||||||
|
Make sure each key starts with the given namespace
|
||||||
|
"""
|
||||||
|
if self.namespace is not None and not key.startswith(self.namespace):
|
||||||
|
key = self.namespace + ":" + key
|
||||||
|
|
||||||
|
return key
|
||||||
|
|
||||||
def set_cache(self, key, value, **kwargs):
|
def set_cache(self, key, value, **kwargs):
|
||||||
ttl = kwargs.get("ttl", None)
|
ttl = kwargs.get("ttl", None)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
|
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
|
||||||
)
|
)
|
||||||
|
key = self.check_and_fix_namespace(key=key)
|
||||||
try:
|
try:
|
||||||
self.redis_client.set(name=key, value=str(value), ex=ttl)
|
self.redis_client.set(name=key, value=str(value), ex=ttl)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -158,6 +175,7 @@ class RedisCache(BaseCache):
|
||||||
|
|
||||||
async def async_set_cache(self, key, value, **kwargs):
|
async def async_set_cache(self, key, value, **kwargs):
|
||||||
_redis_client = self.init_async_client()
|
_redis_client = self.init_async_client()
|
||||||
|
key = self.check_and_fix_namespace(key=key)
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
ttl = kwargs.get("ttl", None)
|
ttl = kwargs.get("ttl", None)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
|
@ -187,6 +205,7 @@ class RedisCache(BaseCache):
|
||||||
async with redis_client.pipeline(transaction=True) as pipe:
|
async with redis_client.pipeline(transaction=True) as pipe:
|
||||||
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
||||||
for cache_key, cache_value in cache_list:
|
for cache_key, cache_value in cache_list:
|
||||||
|
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
||||||
)
|
)
|
||||||
|
@ -213,6 +232,7 @@ class RedisCache(BaseCache):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
|
f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
|
||||||
)
|
)
|
||||||
|
key = self.check_and_fix_namespace(key=key)
|
||||||
self.redis_batch_writing_buffer.append((key, value))
|
self.redis_batch_writing_buffer.append((key, value))
|
||||||
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
||||||
await self.flush_cache_buffer()
|
await self.flush_cache_buffer()
|
||||||
|
@ -242,6 +262,7 @@ class RedisCache(BaseCache):
|
||||||
|
|
||||||
def get_cache(self, key, **kwargs):
|
def get_cache(self, key, **kwargs):
|
||||||
try:
|
try:
|
||||||
|
key = self.check_and_fix_namespace(key=key)
|
||||||
print_verbose(f"Get Redis Cache: key: {key}")
|
print_verbose(f"Get Redis Cache: key: {key}")
|
||||||
cached_response = self.redis_client.get(key)
|
cached_response = self.redis_client.get(key)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
|
@ -255,6 +276,7 @@ class RedisCache(BaseCache):
|
||||||
|
|
||||||
async def async_get_cache(self, key, **kwargs):
|
async def async_get_cache(self, key, **kwargs):
|
||||||
_redis_client = self.init_async_client()
|
_redis_client = self.init_async_client()
|
||||||
|
key = self.check_and_fix_namespace(key=key)
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
try:
|
try:
|
||||||
print_verbose(f"Get Async Redis Cache: key: {key}")
|
print_verbose(f"Get Async Redis Cache: key: {key}")
|
||||||
|
@ -281,6 +303,7 @@ class RedisCache(BaseCache):
|
||||||
async with redis_client.pipeline(transaction=True) as pipe:
|
async with redis_client.pipeline(transaction=True) as pipe:
|
||||||
# Queue the get operations in the pipeline for all keys.
|
# Queue the get operations in the pipeline for all keys.
|
||||||
for cache_key in key_list:
|
for cache_key in key_list:
|
||||||
|
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||||
pipe.get(cache_key) # Queue GET command in pipeline
|
pipe.get(cache_key) # Queue GET command in pipeline
|
||||||
|
|
||||||
# Execute the pipeline and await the results.
|
# Execute the pipeline and await the results.
|
||||||
|
@ -1015,6 +1038,9 @@ class Cache:
|
||||||
self.redis_flush_size = redis_flush_size
|
self.redis_flush_size = redis_flush_size
|
||||||
self.ttl = ttl
|
self.ttl = ttl
|
||||||
|
|
||||||
|
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
||||||
|
self.cache.namespace = self.namespace
|
||||||
|
|
||||||
def get_cache_key(self, *args, **kwargs):
|
def get_cache_key(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Get the cache key for the given arguments.
|
Get the cache key for the given arguments.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue