Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix: support async redis caching

commit 007870390d (parent 817a3d29b7)
6 changed files with 357 additions and 122 deletions
litellm/caching.py

@@ -26,9 +26,18 @@ class BaseCache:
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError

+    async def async_set_cache(self, key, value, **kwargs):
+        raise NotImplementedError
+
     def get_cache(self, key, **kwargs):
         raise NotImplementedError

+    async def async_get_cache(self, key, **kwargs):
+        raise NotImplementedError
+
+    async def disconnect(self):
+        raise NotImplementedError
+

 class InMemoryCache(BaseCache):
     def __init__(self):
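Note: with this hunk, BaseCache defines matching sync and async entry points, so every backend can be driven from both code paths. A minimal sketch of a custom backend against this surface (the DictCache name is illustrative, not part of the commit):

    # Illustrative subclass of the BaseCache interface added above.
    class DictCache(BaseCache):
        def __init__(self):
            self.store = {}

        def set_cache(self, key, value, **kwargs):
            self.store[key] = value

        async def async_set_cache(self, key, value, **kwargs):
            # No real I/O here, so delegating to the sync path is safe.
            self.set_cache(key=key, value=value, **kwargs)

        def get_cache(self, key, **kwargs):
            return self.store.get(key)

        async def async_get_cache(self, key, **kwargs):
            return self.get_cache(key=key, **kwargs)

        async def disconnect(self):
            pass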
@@ -41,6 +50,9 @@ class InMemoryCache(BaseCache):
         if "ttl" in kwargs:
             self.ttl_dict[key] = time.time() + kwargs["ttl"]

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -55,16 +67,21 @@ class InMemoryCache(BaseCache):
                 return cached_response
         return None

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()

+    async def disconnect(self):
+        pass
+

 class RedisCache(BaseCache):
-    def __init__(self, host=None, port=None, password=None, **kwargs):
-        import redis
-
-        # if users don't provider one, use the default litellm cache
+    # if users don't provider one, use the default litellm cache
+    def __init__(self, host=None, port=None, password=None, **kwargs):
+        from ._redis import get_redis_client

         redis_kwargs = {}
@@ -76,8 +93,13 @@ class RedisCache(BaseCache):
             redis_kwargs["password"] = password

         redis_kwargs.update(kwargs)

         self.redis_client = get_redis_client(**redis_kwargs)
+        self.redis_kwargs = redis_kwargs
+
+    def init_async_client(self):
+        from ._redis import get_redis_async_client
+
+        return get_redis_async_client(**self.redis_kwargs)

     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
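Note: the sync client is created once and kept on self.redis_client, while self.redis_kwargs is stored so the async paths below can build a fresh client per call. Assuming get_redis_async_client returns a redis.asyncio client (redis-py >= 4.2), the "async with" blocks below acquire and close the client around each operation, roughly:

    import redis.asyncio as async_redis

    async def example(redis_kwargs: dict):
        # Sketch of the per-call pattern used by the async methods below;
        # the client's connections are released when the block exits.
        async with async_redis.Redis(**redis_kwargs) as client:
            await client.set("litellm-demo", "value", ex=60)
            return await client.get("litellm-demo")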
@@ -88,6 +110,34 @@ class RedisCache(BaseCache):
             # NON blocking - notify users Redis is throwing an exception
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

+    async def async_set_cache(self, key, value, **kwargs):
+        async with self.init_async_client() as redis_client:
+            ttl = kwargs.get("ttl", None)
+            print_verbose(
+                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+            )
+            try:
+                await redis_client.set(name=key, value=str(value), ex=ttl)
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+
+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+        # cached_response is in `b{} convert it to ModelResponse
+        cached_response = cached_response.decode("utf-8")  # Convert bytes to string
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
     def get_cache(self, key, **kwargs):
         try:
             print_verbose(f"Get Redis Cache: key: {key}")
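Note: set_cache and async_set_cache store str(value), so a cached dict comes back as a Python-repr string with single quotes; json.loads rejects that, which is why _get_cache_logic falls back to ast.literal_eval. A standalone sketch of that round trip (values illustrative):

    import ast, json

    stored = str({"timestamp": 1704067200.0, "response": "hi"})  # what set_cache writes
    raw = stored.encode("utf-8")  # what the redis client hands back: bytes

    decoded = raw.decode("utf-8")
    try:
        value = json.loads(decoded)  # fails: Python repr uses single quotes
    except json.JSONDecodeError:
        value = ast.literal_eval(decoded)  # parses the Python literal instead

    assert value["response"] == "hi"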
@@ -95,26 +145,33 @@ class RedisCache(BaseCache):
             print_verbose(
                 f"Got Redis Cache: key: {key}, cached_response {cached_response}"
             )
-            if cached_response != None:
-                # cached_response is in `b{} convert it to ModelResponse
-                cached_response = cached_response.decode(
-                    "utf-8"
-                )  # Convert bytes to string
-                try:
-                    cached_response = json.loads(
-                        cached_response
-                    )  # Convert string to dictionary
-                except:
-                    cached_response = ast.literal_eval(cached_response)
-                return cached_response
+            return self._get_cache_logic(cached_response=cached_response)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+    async def async_get_cache(self, key, **kwargs):
+        async with self.init_async_client() as redis_client:
+            try:
+                print_verbose(f"Get Redis Cache: key: {key}")
+                cached_response = await redis_client.get(key)
+                print_verbose(
+                    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
+                )
+                response = self._get_cache_logic(cached_response=cached_response)
+                return response
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                traceback.print_exc()
+                logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
+
     def flush_cache(self):
         self.redis_client.flushall()

+    async def disconnect(self):
+        pass
+

 class S3Cache(BaseCache):
     def __init__(
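Note: taken together, RedisCache can now be driven entirely from async code. A hedged usage sketch (assumes a Redis server on localhost:6379 and RedisCache as defined in this file):

    import asyncio

    async def main():
        cache = RedisCache(host="localhost", port=6379)
        await cache.async_set_cache("demo-key", {"response": "hello"}, ttl=120)
        cached = await cache.async_get_cache("demo-key")
        print(cached)  # {'response': 'hello'} after the decode/literal_eval round trip

    asyncio.run(main())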
@@ -189,6 +246,9 @@ class S3Cache(BaseCache):
             # NON blocking - notify users S3 is throwing an exception
             print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
     def get_cache(self, key, **kwargs):
         import boto3, botocore
@@ -229,6 +289,9 @@ class S3Cache(BaseCache):
             traceback.print_exc()
             print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         pass
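Note: unlike RedisCache, the S3 async methods simply call the sync boto3 paths, so they still block the event loop. A non-blocking variant (not what this commit does; needs Python 3.9+ for asyncio.to_thread) could offload to a worker thread:

    import asyncio

    # Hypothetical alternative: run the blocking boto3 call in a thread
    # so the event loop stays responsive.
    async def async_get_cache(self, key, **kwargs):
        return await asyncio.to_thread(self.get_cache, key, **kwargs)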
@@ -468,6 +531,45 @@ class Cache:
             }
         time.sleep(0.02)

+    def _get_cache_logic(
+        self,
+        cached_result: Optional[Any],
+        max_age: Optional[float],
+    ):
+        """
+        Common get cache logic across sync + async implementations
+        """
+        # Check if a timestamp was stored with the cached response
+        if (
+            cached_result is not None
+            and isinstance(cached_result, dict)
+            and "timestamp" in cached_result
+        ):
+            timestamp = cached_result["timestamp"]
+            current_time = time.time()
+
+            # Calculate age of the cached response
+            response_age = current_time - timestamp
+
+            # Check if the cached response is older than the max-age
+            if max_age is not None and response_age > max_age:
+                return None  # Cached response is too old
+
+            # If the response is fresh, or there's no max-age requirement, return the cached response
+            # cached_response is in `b{} convert it to ModelResponse
+            cached_response = cached_result.get("response")
+            try:
+                if isinstance(cached_response, dict):
+                    pass
+                else:
+                    cached_response = json.loads(
+                        cached_response  # type: ignore
+                    )  # Convert string to dictionary
+            except:
+                cached_response = ast.literal_eval(cached_response)  # type: ignore
+            return cached_response
+        return cached_result
+
     def get_cache(self, *args, **kwargs):
         """
         Retrieves the cached result for the given arguments.
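Note: this factors the s-maxage freshness check out of get_cache so the async path can share it. Cached entries carry the timestamp written at set time, and a cache={"s-maxage": N} argument drops anything older than N seconds. A worked sketch with illustrative values:

    import time

    cached_result = {"timestamp": time.time() - 120, "response": {"id": "cached"}}
    max_age = 60  # from cache={"s-maxage": 60} on the call

    response_age = time.time() - cached_result["timestamp"]  # roughly 120s
    assert response_age > max_age  # stale, so _get_cache_logic returns None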
@@ -490,53 +592,40 @@ class Cache:
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
                 cached_result = self.cache.get_cache(cache_key)
-                # Check if a timestamp was stored with the cached response
-                if (
-                    cached_result is not None
-                    and isinstance(cached_result, dict)
-                    and "timestamp" in cached_result
-                ):
-                    timestamp = cached_result["timestamp"]
-                    current_time = time.time()
-
-                    # Calculate age of the cached response
-                    response_age = current_time - timestamp
-
-                    # Check if the cached response is older than the max-age
-                    if max_age is not None and response_age > max_age:
-                        print_verbose(
-                            f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
-                        )
-                        return None  # Cached response is too old
-
-                    # If the response is fresh, or there's no max-age requirement, return the cached response
-                    # cached_response is in `b{} convert it to ModelResponse
-                    cached_response = cached_result.get("response")
-                    try:
-                        if isinstance(cached_response, dict):
-                            pass
-                        else:
-                            cached_response = json.loads(
-                                cached_response
-                            )  # Convert string to dictionary
-                    except:
-                        cached_response = ast.literal_eval(cached_response)
-                    return cached_response
-                return cached_result
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
         except Exception as e:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None

-    def add_cache(self, result, *args, **kwargs):
+    async def async_get_cache(self, *args, **kwargs):
         """
-        Adds a result to the cache.
+        Async get cache implementation.

-        Args:
-            *args: args to litellm.completion() or embedding()
-            **kwargs: kwargs to litellm.completion() or embedding()
+        Used for embedding calls in async wrapper
         """
+        try:  # never block execution
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
+                cached_result = await self.cache.async_get_cache(cache_key)
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
+        except Exception as e:
+            print_verbose(f"An exception occurred: {traceback.format_exc()}")
+            return None

-        Returns:
-            None
+    def _add_cache_logic(self, result, *args, **kwargs):
+        """
+        Common implementation across sync + async add_cache functions
         """
         try:
             if "cache_key" in kwargs:
@@ -555,17 +644,49 @@ class Cache:
                     if k == "ttl":
                         kwargs["ttl"] = v
                 cached_data = {"timestamp": time.time(), "response": result}
-                self.cache.set_cache(cache_key, cached_data, **kwargs)
+                return cache_key, cached_data
+            else:
+                raise Exception("cache key is None")
         except Exception as e:
+            raise e
+
+    def add_cache(self, result, *args, **kwargs):
+        """
+        Adds a result to the cache.
+
+        Args:
+            *args: args to litellm.completion() or embedding()
+            **kwargs: kwargs to litellm.completion() or embedding()
+
+        Returns:
+            None
+        """
+        try:
+            cache_key, cached_data = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            self.cache.set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
             pass

-    async def _async_add_cache(self, result, *args, **kwargs):
-        self.add_cache(result, *args, **kwargs)
+    async def async_add_cache(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache
+        """
+        try:
+            cache_key, cached_data = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()

-    async def _async_get_cache(self, *args, **kwargs):
-        return self.get_cache(*args, **kwargs)
+    async def disconnect(self):
+        if hasattr(self.cache, "disconnect"):
+            await self.cache.disconnect()


 def enable_cache(
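Note: with async_add_cache, async_get_cache, and disconnect in place, the Cache object can be awaited end to end. A hedged end-to-end sketch (uses the in-memory backend so no Redis server is needed; the kwargs mimic a completion call):

    import asyncio

    async def main():
        cache = Cache(type="local")  # InMemoryCache-backed

        kwargs = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
        }
        await cache.async_add_cache({"choices": ["hello"]}, **kwargs)

        hit = await cache.async_get_cache(**kwargs)
        print(hit)  # {'choices': ['hello']}
        await cache.disconnect()

    asyncio.run(main())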