fix: support async redis caching

This commit is contained in:
Krrish Dholakia 2024-01-12 21:46:41 +05:30
parent 817a3d29b7
commit 007870390d
6 changed files with 357 additions and 122 deletions
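For orientation, here is a minimal sketch of how the async caching path touched by this commit might be exercised end to end. It assumes Redis connection details and provider credentials are available as environment variables; the call shapes mirror the cache configuration used in the tests further down, not a prescribed API.
```
# Hypothetical end-to-end sketch; not part of the diff below.
import asyncio
import os

import litellm
from litellm.caching import Cache


async def main():
    # Point the shared LiteLLM cache at Redis.
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )

    # Async completions can now read/write the cache without blocking the event loop.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        caching=True,
    )
    print(response.id)

    # Close the async Redis connection on shutdown (see the proxy shutdown hook below).
    await litellm.cache.disconnect()


asyncio.run(main())
```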

View file

@@ -11,6 +11,7 @@
import os
import inspect
import redis, litellm
import redis.asyncio as async_redis
from typing import List, Optional
@@ -67,7 +68,10 @@ def get_redis_url_from_environment():
)
def get_redis_client(**env_overrides):
def _get_redis_client_logic(**env_overrides):
"""
Common functionality across sync + async redis client implementations
"""
### check if "os.environ/<key-name>" passed in
for k, v in env_overrides.items():
if isinstance(v, str) and v.startswith("os.environ/"):
@@ -85,9 +89,21 @@ def get_redis_client(**env_overrides):
redis_kwargs.pop("port", None)
redis_kwargs.pop("db", None)
redis_kwargs.pop("password", None)
return redis.Redis.from_url(**redis_kwargs)
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
return redis.Redis.from_url(**redis_kwargs)
return redis.Redis(**redis_kwargs)
def get_redis_async_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
return async_redis.Redis.from_url(**redis_kwargs)
return async_redis.Redis(socket_timeout=5, **redis_kwargs)
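A rough usage sketch of the factory split above: `_get_redis_client_logic` normalizes the kwargs once, and each wrapper hands them to the matching redis-py client. The snippet assumes a reachable Redis at the given URL; the `async with` pattern mirrors how `RedisCache` consumes the async client further down.
```
# Sketch only: exercising the async client factory from this hunk.
import asyncio

from litellm._redis import get_redis_async_client


async def demo():
    # When "url" is supplied, _get_redis_client_logic drops host/port/db/password
    # and the async client is built via Redis.from_url.
    client = get_redis_async_client(url="redis://localhost:6379/0")
    async with client as redis_client:
        await redis_client.set(name="greeting", value="hello", ex=60)
        value = await redis_client.get("greeting")
        print(value)  # b"hello" -- redis-py returns bytes by default


asyncio.run(demo())
```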

View file

@@ -26,9 +26,18 @@ class BaseCache:
def set_cache(self, key, value, **kwargs):
raise NotImplementedError
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError
def get_cache(self, key, **kwargs):
raise NotImplementedError
async def async_get_cache(self, key, **kwargs):
raise NotImplementedError
async def disconnect(self):
raise NotImplementedError
class InMemoryCache(BaseCache):
def __init__(self):
@@ -41,6 +50,9 @@ class InMemoryCache(BaseCache):
if "ttl" in kwargs:
self.ttl_dict[key] = time.time() + kwargs["ttl"]
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
def get_cache(self, key, **kwargs):
if key in self.cache_dict:
if key in self.ttl_dict:
@@ -55,16 +67,21 @@ class InMemoryCache(BaseCache):
return cached_response
return None
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
async def disconnect(self):
pass
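With the async hooks added to `BaseCache` above, a backend implements the same six methods for both paths. A hypothetical sketch (class name and storage are illustrative, and it assumes `BaseCache` is importable from `litellm.caching`) that delegates async to sync the same way `InMemoryCache` does:
```
# Illustrative only: a minimal backend against the BaseCache contract above.
from litellm.caching import BaseCache


class DictCache(BaseCache):
    def __init__(self):
        self.store = {}

    def set_cache(self, key, value, **kwargs):
        self.store[key] = value

    async def async_set_cache(self, key, value, **kwargs):
        # No real I/O, so delegating to the sync path is enough.
        self.set_cache(key=key, value=value, **kwargs)

    def get_cache(self, key, **kwargs):
        return self.store.get(key)

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key=key, **kwargs)

    def flush_cache(self):
        self.store.clear()

    async def disconnect(self):
        pass  # nothing to close for an in-process dict
```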
class RedisCache(BaseCache):
def __init__(self, host=None, port=None, password=None, **kwargs):
import redis
# if users don't provide one, use the default litellm cache
# if users don't provide one, use the default litellm cache
def __init__(self, host=None, port=None, password=None, **kwargs):
from ._redis import get_redis_client
redis_kwargs = {}
@@ -76,8 +93,13 @@ class RedisCache(BaseCache):
redis_kwargs["password"] = password
redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs)
self.redis_kwargs = redis_kwargs
def init_async_client(self):
from ._redis import get_redis_async_client
return get_redis_async_client(**self.redis_kwargs)
def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None)
@@ -88,6 +110,34 @@ class RedisCache(BaseCache):
# NON blocking - notify users Redis is throwing an exception
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
async def async_set_cache(self, key, value, **kwargs):
async with self.init_async_client() as redis_client:
ttl = kwargs.get("ttl", None)
print_verbose(
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
await redis_client.set(name=key, value=str(value), ex=ttl)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
def _get_cache_logic(self, cached_response: Any):
"""
Common 'get_cache_logic' across sync + async redis client implementations
"""
if cached_response is None:
return cached_response
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_response.decode("utf-8") # Convert bytes to string
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)
return cached_response
def get_cache(self, key, **kwargs):
try:
print_verbose(f"Get Redis Cache: key: {key}")
@@ -95,26 +145,33 @@ class RedisCache(BaseCache):
print_verbose(
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
)
if cached_response != None:
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_response.decode(
"utf-8"
) # Convert bytes to string
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)
return cached_response
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
async def async_get_cache(self, key, **kwargs):
async with self.init_async_client() as redis_client:
try:
print_verbose(f"Get Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
print_verbose(
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
)
response = self._get_cache_logic(cached_response=cached_response)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
def flush_cache(self):
self.redis_client.flushall()
async def disconnect(self):
pass
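A standalone illustration of the decode path in `_get_cache_logic` above: values are written with `str(value)`, so a stored dict comes back as a Python repr (single quotes) that `json.loads` rejects and `ast.literal_eval` recovers. Sample values are made up.
```
# Standalone illustration of the decode fallback; sample values are made up.
import ast
import json

# RedisCache.set_cache writes str(value), so a stored dict is a Python repr
# with single quotes rather than JSON.
stored = str({"timestamp": 1705072000.0, "response": {"id": "chatcmpl-123"}})
cached_response = stored.encode("utf-8")  # redis-py returns bytes

decoded = cached_response.decode("utf-8")
try:
    result = json.loads(decoded)        # fails: JSON needs double quotes
except json.JSONDecodeError:
    result = ast.literal_eval(decoded)  # succeeds on the Python repr

print(result["response"]["id"])  # chatcmpl-123
```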
class S3Cache(BaseCache):
def __init__(
@@ -189,6 +246,9 @@ class S3Cache(BaseCache):
# NON blocking - notify users S3 is throwing an exception
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
def get_cache(self, key, **kwargs):
import boto3, botocore
@@ -229,6 +289,9 @@ class S3Cache(BaseCache):
traceback.print_exc()
print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
def flush_cache(self):
pass
@@ -468,6 +531,45 @@ class Cache:
}
time.sleep(0.02)
def _get_cache_logic(
self,
cached_result: Optional[Any],
max_age: Optional[float],
):
"""
Common get cache logic across sync + async implementations
"""
# Check if a timestamp was stored with the cached response
if (
cached_result is not None
and isinstance(cached_result, dict)
and "timestamp" in cached_result
):
timestamp = cached_result["timestamp"]
current_time = time.time()
# Calculate age of the cached response
response_age = current_time - timestamp
# Check if the cached response is older than the max-age
if max_age is not None and response_age > max_age:
return None # Cached response is too old
# If the response is fresh, or there's no max-age requirement, return the cached response
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_result.get("response")
try:
if isinstance(cached_response, dict):
pass
else:
cached_response = json.loads(
cached_response # type: ignore
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response) # type: ignore
return cached_response
return cached_result
def get_cache(self, *args, **kwargs):
"""
Retrieves the cached result for the given arguments.
@@ -490,53 +592,40 @@ class Cache:
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = self.cache.get_cache(cache_key)
# Check if a timestamp was stored with the cached response
if (
cached_result is not None
and isinstance(cached_result, dict)
and "timestamp" in cached_result
):
timestamp = cached_result["timestamp"]
current_time = time.time()
# Calculate age of the cached response
response_age = current_time - timestamp
# Check if the cached response is older than the max-age
if max_age is not None and response_age > max_age:
print_verbose(
f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
)
return None # Cached response is too old
# If the response is fresh, or there's no max-age requirement, return the cached response
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_result.get("response")
try:
if isinstance(cached_response, dict):
pass
else:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)
return cached_response
return cached_result
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception as e:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
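The timestamp stored alongside each cached response is what the `s-maxage` / `s-max-age` check above compares against. A hedged sketch of passing that freshness bound through the `cache` kwarg (it assumes `litellm.cache` is already configured and that both calls produce the same cache key; whether the second call is served from cache depends on the entry's age):
```
# Sketch: requesting a freshness bound via the cache-control style kwargs above.
# Assumes litellm.cache has been configured, e.g. Cache(type="redis", ...),
# and that provider credentials are set in the environment.
from litellm import completion

messages = [{"role": "user", "content": "what's the weather like?"}]

# First call runs against the provider and stores
# {"timestamp": time.time(), "response": ...} under the computed cache key.
first = completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# Only accept a cached entry younger than 60 seconds; anything older is
# treated as a miss by _get_cache_logic and the request is re-sent.
second = completion(model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 60})
```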
def add_cache(self, result, *args, **kwargs):
async def async_get_cache(self, *args, **kwargs):
"""
Adds a result to the cache.
Async get cache implementation.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Used for embedding calls in async wrapper
"""
try: # never block execution
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(cache_key)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception as e:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
Returns:
None
def _add_cache_logic(self, result, *args, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
try:
if "cache_key" in kwargs:
@@ -555,17 +644,49 @@ class Cache:
if k == "ttl":
kwargs["ttl"] = v
cached_data = {"timestamp": time.time(), "response": result}
self.cache.set_cache(cache_key, cached_data, **kwargs)
return cache_key, cached_data
else:
raise Exception("cache key is None")
except Exception as e:
raise e
def add_cache(self, result, *args, **kwargs):
"""
Adds a result to the cache.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
None
"""
try:
cache_key, cached_data = self._add_cache_logic(
result=result, *args, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
pass
async def _async_add_cache(self, result, *args, **kwargs):
self.add_cache(result, *args, **kwargs)
async def async_add_cache(self, result, *args, **kwargs):
"""
Async implementation of add_cache
"""
try:
cache_key, cached_data = self._add_cache_logic(
result=result, *args, **kwargs
)
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
async def _async_get_cache(self, *args, **kwargs):
return self.get_cache(*args, **kwargs)
async def disconnect(self):
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
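Taken together, `async_add_cache`, `async_get_cache`, and `disconnect` form the async surface of `Cache`. A minimal sketch of driving it directly, loosely mirroring the new `test_redis_cache_basic` test further down (the stored dict and key handling are illustrative):
```
# Sketch: calling the async Cache surface added above directly (illustrative values).
import asyncio
import os

import litellm
from litellm.caching import Cache


async def demo():
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )

    messages = [{"role": "user", "content": "ping"}]
    cache_key = litellm.cache.get_cache_key(model="gpt-3.5-turbo", messages=messages)

    # async_add_cache -> _add_cache_logic -> RedisCache.async_set_cache
    await litellm.cache.async_add_cache({"id": "demo-123"}, cache_key=cache_key)

    # async_get_cache -> RedisCache.async_get_cache -> Cache._get_cache_logic
    stored = await litellm.cache.async_get_cache(cache_key=cache_key)
    print(stored)  # {'id': 'demo-123'}

    await litellm.cache.disconnect()


asyncio.run(demo())
```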
def enable_cache(

View file

@@ -346,7 +346,7 @@ def run_server(
import gunicorn.app.base
except:
raise ImportError(
"Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`"
"uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
)
if config is not None:
@@ -427,36 +427,40 @@ def run_server(
f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n"
) # noqa
# Gunicorn Application Class
class StandaloneApplication(gunicorn.app.base.BaseApplication):
def __init__(self, app, options=None):
self.options = options or {} # gunicorn options
self.application = app # FastAPI app
super().__init__()
uvicorn.run(
"litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
)
def load_config(self):
# note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
config = {
key: value
for key, value in self.options.items()
if key in self.cfg.settings and value is not None
}
for key, value in config.items():
self.cfg.set(key.lower(), value)
# # Gunicorn Application Class
# class StandaloneApplication(gunicorn.app.base.BaseApplication):
# def __init__(self, app, options=None):
# self.options = options or {} # gunicorn options
# self.application = app # FastAPI app
# super().__init__()
def load(self):
# gunicorn app function
return self.application
# def load_config(self):
# # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
# config = {
# key: value
# for key, value in self.options.items()
# if key in self.cfg.settings and value is not None
# }
# for key, value in config.items():
# self.cfg.set(key.lower(), value)
gunicorn_options = {
"bind": f"{host}:{port}",
"workers": num_workers, # default is 1
"worker_class": "uvicorn.workers.UvicornWorker",
"preload": True, # Add the preload flag
}
from litellm.proxy.proxy_server import app
# def load(self):
# # gunicorn app function
# return self.application
StandaloneApplication(app=app, options=gunicorn_options).run() # Run gunicorn
# gunicorn_options = {
# "bind": f"{host}:{port}",
# "workers": num_workers, # default is 1
# "worker_class": "uvicorn.workers.UvicornWorker",
# "preload": True, # Add the preload flag
# }
# from litellm.proxy.proxy_server import app
# StandaloneApplication(app=app, options=gunicorn_options).run() # Run gunicorn
if __name__ == "__main__":

View file

@@ -7,6 +7,20 @@ import secrets, subprocess
import hashlib, uuid
import warnings
import importlib
import warnings
def showwarning(message, category, filename, lineno, file=None, line=None):
traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
if file is not None:
file.write(traceback_info)
warnings.showwarning = showwarning
warnings.filterwarnings("default", category=UserWarning)
# Your client code here
messages: list = []
sys.path.insert(
@@ -2510,10 +2524,12 @@ async def get_routes():
@router.on_event("shutdown")
async def shutdown_event():
global prisma_client, master_key, user_custom_auth
if prisma_client:
if prisma_client is not None:
verbose_proxy_logger.debug("Disconnecting from Prisma")
await prisma_client.disconnect()
if litellm.cache is not None:
await litellm.cache.disconnect()
## RESET CUSTOM VARIABLES ##
cleanup_router_config_variables()

View file

@@ -266,8 +266,9 @@ async def test_embedding_caching_azure_individual_items():
"""
Tests caching for individual items in an embedding list
Assert that the same EmbeddingResponse object is returned for the duplicate item in 2 embedding list calls
- Cache an item
- call aembedding(..) with the item + 1 unique item
- compare to a 2nd aembedding (...) with 2 unique items
```
embedding_1 = ["hey how's it going", "I'm doing well"]
embedding_val_1 = embedding(...)
@@ -280,31 +281,98 @@ async def test_embedding_caching_azure_individual_items():
"""
litellm.cache = Cache()
common_msg = f"hey how's it going {uuid.uuid4()}"
embedding_1 = [common_msg, "I'm doing well"]
embedding_2 = [common_msg, "I'm fine"]
common_msg_2 = f"hey how's it going {uuid.uuid4()}"
embedding_2 = [
common_msg,
f"I'm fine {uuid.uuid4()}",
common_msg,
common_msg,
common_msg,
] * 20
embedding_3 = [
common_msg_2,
common_msg_2,
common_msg_2,
common_msg_2,
f"I'm fine {uuid.uuid4()}",
] * 20 # make sure azure doesn't return cached 'i'm fine' responses
embedding_val_1 = await aembedding(
model="azure/azure-embedding-model", input=embedding_1, caching=True
)
second_response_start_time = time.time()
embedding_val_2 = await aembedding(
model="azure/azure-embedding-model", input=embedding_2, caching=True
)
print(f"embedding_val_2: {embedding_val_2}")
if (
embedding_val_2["data"][0]["embedding"]
!= embedding_val_1["data"][0]["embedding"]
):
print(f"embedding1: {embedding_val_1}")
print(f"embedding2: {embedding_val_2}")
pytest.fail("Error occurred: Embedding caching failed")
if (
embedding_val_2["data"][1]["embedding"]
== embedding_val_1["data"][1]["embedding"]
):
print(f"embedding1: {embedding_val_1}")
print(f"embedding2: {embedding_val_2}")
pytest.fail("Error occurred: Embedding caching failed")
if embedding_val_2 is not None:
second_response_end_time = time.time()
second_response_time = second_response_end_time - second_response_start_time
third_response_start_time = time.time()
embedding_val_3 = await aembedding(
model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
)
if embedding_val_3 is not None:
third_response_end_time = time.time()
third_response_time = third_response_end_time - third_response_start_time
print(f"second_response_time: {second_response_time}")
print(f"third_response_time: {third_response_time}")
assert (
second_response_time < third_response_time - 0.5
) # make sure it's actually faster
raise Exception(f"it works {second_response_time} < {third_response_time}")
@pytest.mark.asyncio
async def test_redis_cache_basic():
"""
Init redis client
- write to client
- read from client
"""
litellm.set_verbose = False
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
)
cache_key = litellm.cache.get_cache_key(
model="gpt-3.5-turbo",
messages=messages,
)
print(f"cache_key: {cache_key}")
litellm.cache.add_cache(result=response1, cache_key=cache_key)
print(f"cache key pre async get: {cache_key}")
stored_val = await litellm.cache.async_get_cache(
model="gpt-3.5-turbo",
messages=messages,
)
print(f"stored_val: {stored_val}")
assert stored_val["id"] == response1.id
raise Exception("it worked!")
def test_redis_cache_completion():

View file

@@ -2214,8 +2214,13 @@ def client(original_function):
)
# if caching is false, don't run this
final_embedding_cached_response = None
if (
(kwargs.get("caching", None) is None and litellm.cache is not None)
(
kwargs.get("caching", None) is None
and kwargs.get("cache", None) is None
and litellm.cache is not None
)
or kwargs.get("caching", False) == True
or (
kwargs.get("cache", None) is not None
@@ -2234,12 +2239,13 @@ def client(original_function):
kwargs["input"], list
):
tasks = []
embedding_kwargs = copy.deepcopy(kwargs)
for idx, i in enumerate(kwargs["input"]):
embedding_kwargs["input"] = i
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
)
tasks.append(
litellm.cache._async_get_cache(
*args, **embedding_kwargs
litellm.cache.async_get_cache(
cache_key=preset_cache_key
)
)
cached_result = await asyncio.gather(*tasks)
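The wrapper above builds one preset cache key per embedding input and resolves the lookups concurrently with `asyncio.gather`. A pared-down, self-contained sketch of that per-item pattern (helper names and key format are made up; a plain dict stands in for the Redis-backed cache):
```
# Pared-down sketch of the per-item lookup pattern above; names are made up.
import asyncio


async def lookup(cache, key):
    # Stand-in for litellm.cache.async_get_cache(cache_key=key).
    return cache.get(key)


async def gather_cached_items(cache, inputs):
    # One lookup task per embedding input, resolved concurrently.
    tasks = [lookup(cache, f"embedding::{text}") for text in inputs]
    cached = await asyncio.gather(*tasks)
    # None marks a miss; only misses would be sent to the provider.
    misses = [text for text, hit in zip(inputs, cached) if hit is None]
    return cached, misses


toy_cache = {"embedding::hello": [0.1, 0.2]}
print(asyncio.run(gather_cached_items(toy_cache, ["hello", "world"])))
```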
@@ -2445,24 +2451,28 @@ def client(original_function):
if isinstance(result, EmbeddingResponse) and isinstance(
kwargs["input"], list
):
embedding_kwargs = copy.deepcopy(kwargs)
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
)
embedding_response = result.data[idx]
embedding_kwargs["input"] = i
asyncio.create_task(
litellm.cache._async_add_cache(
embedding_response, *args, **embedding_kwargs
litellm.cache.async_add_cache(
embedding_response,
*args,
cache_key=preset_cache_key,
)
)
# pass
else:
asyncio.create_task(
litellm.cache._async_add_cache(
litellm.cache.async_add_cache(
result.json(), *args, **kwargs
)
)
else:
asyncio.create_task(
litellm.cache._async_add_cache(result, *args, **kwargs)
litellm.cache.async_add_cache(result, *args, **kwargs)
)
# LOG SUCCESS - handle streaming success logging in the _next_ object
print_verbose(