fix: support async redis caching

Krrish Dholakia 2024-01-12 21:46:41 +05:30
parent 817a3d29b7
commit 007870390d
6 changed files with 357 additions and 122 deletions


@@ -26,9 +26,18 @@ class BaseCache:
def set_cache(self, key, value, **kwargs):
raise NotImplementedError
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError
def get_cache(self, key, **kwargs):
raise NotImplementedError
async def async_get_cache(self, key, **kwargs):
raise NotImplementedError
async def disconnect(self):
raise NotImplementedError
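For orientation: the async surface added to BaseCache mirrors the sync one, so a backend without a native async client can simply delegate. A minimal sketch of a custom backend against this interface (DictCache is a hypothetical example, not part of this commit):

class DictCache(BaseCache):
    def __init__(self):
        self.store = {}

    def set_cache(self, key, value, **kwargs):
        self.store[key] = value

    async def async_set_cache(self, key, value, **kwargs):
        # no native async client here, so delegate to the sync path
        self.set_cache(key=key, value=value, **kwargs)

    def get_cache(self, key, **kwargs):
        return self.store.get(key)

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key=key, **kwargs)

    async def disconnect(self):
        pass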
class InMemoryCache(BaseCache):
def __init__(self):
@@ -41,6 +50,9 @@ class InMemoryCache(BaseCache):
if "ttl" in kwargs:
self.ttl_dict[key] = time.time() + kwargs["ttl"]
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
def get_cache(self, key, **kwargs):
if key in self.cache_dict:
if key in self.ttl_dict:
@@ -55,16 +67,21 @@ class InMemoryCache(BaseCache):
return cached_response
return None
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
async def disconnect(self):
pass
class RedisCache(BaseCache):
# if users don't provide one, use the default litellm cache
def __init__(self, host=None, port=None, password=None, **kwargs):
from ._redis import get_redis_client
redis_kwargs = {}
@@ -76,8 +93,13 @@ class RedisCache(BaseCache):
redis_kwargs["password"] = password
redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs)
self.redis_kwargs = redis_kwargs
def init_async_client(self):
from ._redis import get_redis_async_client
return get_redis_async_client(**self.redis_kwargs)
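Worth noting: rather than holding one long-lived async connection, each async operation below opens a fresh client from self.redis_kwargs and scopes it with "async with", so it is cleaned up per call. A hypothetical standalone illustration (assumes the redis.asyncio client, which supports the async context-manager protocol):

async def redis_ping(cache: RedisCache) -> bool:
    # open a scoped async client; it is closed when the block exits
    async with cache.init_async_client() as client:
        return await client.ping()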
def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None)
@@ -88,6 +110,34 @@ class RedisCache(BaseCache):
# NON blocking - notify users Redis is throwing an exception
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
async def async_set_cache(self, key, value, **kwargs):
async with self.init_async_client() as redis_client:
ttl = kwargs.get("ttl", None)
print_verbose(
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
await redis_client.set(name=key, value=str(value), ex=ttl)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
def _get_cache_logic(self, cached_response: Any):
"""
Common 'get_cache_logic' across sync + async redis client implementations
"""
if cached_response is None:
return cached_response
# cached_response is in b'{}' format - convert it to a dictionary
cached_response = cached_response.decode("utf-8") # Convert bytes to string
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
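The ast.literal_eval fallback matters because values are written with str(value), which produces Python-repr strings (single quotes) rather than JSON. A small self-contained illustration of this decode path:

import ast, json

raw = b"{'ttl': 60}"  # repr-style bytes, as produced by str(dict)
text = raw.decode("utf-8")
try:
    parsed = json.loads(text)  # raises here: single quotes are not valid JSON
except Exception:
    parsed = ast.literal_eval(text)  # parses the Python literal instead
assert parsed == {"ttl": 60}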
def get_cache(self, key, **kwargs):
try:
print_verbose(f"Get Redis Cache: key: {key}")
@@ -95,26 +145,33 @@ class RedisCache(BaseCache):
print_verbose(
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
async def async_get_cache(self, key, **kwargs):
async with self.init_async_client() as redis_client:
try:
print_verbose(f"Get Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
print_verbose(
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
)
response = self._get_cache_logic(cached_response=cached_response)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
def flush_cache(self):
self.redis_client.flushall()
async def disconnect(self):
pass
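Taken together, a rough usage sketch of the new async path (assumes a Redis instance on localhost; not part of this diff):

import asyncio
from litellm.caching import RedisCache

async def main():
    cache = RedisCache(host="localhost", port=6379)
    await cache.async_set_cache("greeting", {"message": "hello"}, ttl=60)
    print(await cache.async_get_cache("greeting"))  # {'message': 'hello'}
    await cache.disconnect()

asyncio.run(main())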
class S3Cache(BaseCache):
def __init__(
@@ -189,6 +246,9 @@ class S3Cache(BaseCache):
# NON blocking - notify users S3 is throwing an exception
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
def get_cache(self, key, **kwargs):
import boto3, botocore
@@ -229,6 +289,9 @@ class S3Cache(BaseCache):
traceback.print_exc()
print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
def flush_cache(self):
pass
@@ -468,6 +531,45 @@ class Cache:
}
time.sleep(0.02)
def _get_cache_logic(
self,
cached_result: Optional[Any],
max_age: Optional[float],
):
"""
Common get cache logic across sync + async implementations
"""
# Check if a timestamp was stored with the cached response
if (
cached_result is not None
and isinstance(cached_result, dict)
and "timestamp" in cached_result
):
timestamp = cached_result["timestamp"]
current_time = time.time()
# Calculate age of the cached response
response_age = current_time - timestamp
# Check if the cached response is older than the max-age
if max_age is not None and response_age > max_age:
return None # Cached response is too old
# If the response is fresh, or there's no max-age requirement, return the cached response
# cached_response may still be a serialized string; convert it to a dictionary
cached_response = cached_result.get("response")
try:
if isinstance(cached_response, dict):
pass
else:
cached_response = json.loads(
cached_response # type: ignore
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response) # type: ignore
return cached_response
return cached_result
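To make the freshness check concrete, a hypothetical call against this helper (values invented for illustration):

import time
from litellm.caching import Cache

cache = Cache(type="local")
entry = {"timestamp": time.time() - 120, "response": {"id": "demo"}}

cache._get_cache_logic(cached_result=entry, max_age=60)   # None - 120s old exceeds 60s
cache._get_cache_logic(cached_result=entry, max_age=300)  # {'id': 'demo'} - still fresh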
def get_cache(self, *args, **kwargs):
"""
Retrieves the cached result for the given arguments.
@@ -490,53 +592,40 @@ class Cache:
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = self.cache.get_cache(cache_key)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception as e:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
async def async_get_cache(self, *args, **kwargs):
"""
Async get cache implementation.
Used for embedding calls in async wrapper
"""
try: # never block execution
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(cache_key)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception as e:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
def _add_cache_logic(self, result, *args, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
try:
if "cache_key" in kwargs:
@@ -555,17 +644,49 @@ class Cache:
if k == "ttl":
kwargs["ttl"] = v
cached_data = {"timestamp": time.time(), "response": result}
return cache_key, cached_data
else:
raise Exception("cache key is None")
except Exception as e:
raise e
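For reference, the write side always wraps the result with a timestamp - exactly the shape the read-side _get_cache_logic inspects. A stand-in illustration:

import time

result = {"id": "demo"}  # stand-in for a completion/embedding response
cached_data = {"timestamp": time.time(), "response": result}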
def add_cache(self, result, *args, **kwargs):
"""
Adds a result to the cache.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
None
"""
try:
cache_key, cached_data = self._add_cache_logic(
result=result, *args, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
pass
async def async_add_cache(self, result, *args, **kwargs):
"""
Async implementation of add_cache
"""
try:
cache_key, cached_data = self._add_cache_logic(
result=result, *args, **kwargs
)
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
async def disconnect(self):
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
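End to end, the intended flow after this change looks roughly like the following sketch (assumes a local Redis and valid provider credentials; the model name is a placeholder):

import asyncio
import litellm
from litellm.caching import Cache

litellm.cache = Cache(type="redis", host="localhost", port="6379")

async def main():
    # first call hits the provider and is written via async_add_cache;
    # the identical second call should be served from Redis
    r1 = await litellm.aembedding(model="text-embedding-ada-002", input=["hi"])
    r2 = await litellm.aembedding(model="text-embedding-ada-002", input=["hi"])
    await litellm.cache.disconnect()

asyncio.run(main())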
def enable_cache(