diff --git a/litellm/_redis.py b/litellm/_redis.py index bee73f134..36f4ef870 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -11,6 +11,7 @@ import os import inspect import redis, litellm +import redis.asyncio as async_redis from typing import List, Optional @@ -67,7 +68,10 @@ def get_redis_url_from_environment(): ) -def get_redis_client(**env_overrides): +def _get_redis_client_logic(**env_overrides): + """ + Common functionality across sync + async redis client implementations + """ ### check if "os.environ/" passed in for k, v in env_overrides.items(): if isinstance(v, str) and v.startswith("os.environ/"): @@ -85,9 +89,21 @@ def get_redis_client(**env_overrides): redis_kwargs.pop("port", None) redis_kwargs.pop("db", None) redis_kwargs.pop("password", None) - - return redis.Redis.from_url(**redis_kwargs) elif "host" not in redis_kwargs or redis_kwargs["host"] is None: raise ValueError("Either 'host' or 'url' must be specified for redis.") litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") + return redis_kwargs + + +def get_redis_client(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return redis.Redis.from_url(**redis_kwargs) return redis.Redis(**redis_kwargs) + + +def get_redis_async_client(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return async_redis.Redis.from_url(**redis_kwargs) + return async_redis.Redis(socket_timeout=5, **redis_kwargs) diff --git a/litellm/caching.py b/litellm/caching.py index 2c01a17c6..b89220e8d 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -26,9 +26,18 @@ class BaseCache: def set_cache(self, key, value, **kwargs): raise NotImplementedError + async def async_set_cache(self, key, value, **kwargs): + raise NotImplementedError + def get_cache(self, key, **kwargs): raise NotImplementedError + async def async_get_cache(self, key, **kwargs): + raise NotImplementedError + + async def disconnect(self): + raise NotImplementedError + class InMemoryCache(BaseCache): def __init__(self): @@ -41,6 +50,9 @@ class InMemoryCache(BaseCache): if "ttl" in kwargs: self.ttl_dict[key] = time.time() + kwargs["ttl"] + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + def get_cache(self, key, **kwargs): if key in self.cache_dict: if key in self.ttl_dict: @@ -55,16 +67,21 @@ class InMemoryCache(BaseCache): return cached_response return None + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + def flush_cache(self): self.cache_dict.clear() self.ttl_dict.clear() + async def disconnect(self): + pass + class RedisCache(BaseCache): - def __init__(self, host=None, port=None, password=None, **kwargs): - import redis + # if users don't provider one, use the default litellm cache - # if users don't provider one, use the default litellm cache + def __init__(self, host=None, port=None, password=None, **kwargs): from ._redis import get_redis_client redis_kwargs = {} @@ -76,8 +93,13 @@ class RedisCache(BaseCache): redis_kwargs["password"] = password redis_kwargs.update(kwargs) - self.redis_client = get_redis_client(**redis_kwargs) + self.redis_kwargs = redis_kwargs + + def init_async_client(self): + from ._redis import get_redis_async_client + + return get_redis_async_client(**self.redis_kwargs) def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) @@ -88,6 +110,34 @@ 
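The new `get_redis_async_client` factory mirrors the sync path but returns a `redis.asyncio` client, which `RedisCache.init_async_client` then opens per call. A minimal standalone sketch of that usage pattern, assuming redis-py with the `redis.asyncio` module and a Redis instance reachable at localhost:6379 (both placeholders):

```python
import asyncio
import redis.asyncio as async_redis


async def main():
    # Placeholder connection details -- substitute your own host/port/password.
    client = async_redis.Redis(host="localhost", port=6379, socket_timeout=5)
    # Used as an async context manager, mirroring
    # `async with self.init_async_client() as redis_client:` in the patch.
    async with client as redis_client:
        await redis_client.set(name="litellm:demo", value="hello", ex=60)
        cached = await redis_client.get("litellm:demo")
        print(cached)  # b'hello' -- Redis returns raw bytes


asyncio.run(main())
```

Deferring client construction to `init_async_client` keeps `RedisCache.__init__` free of event-loop concerns; only the connection kwargs are stored at construction time.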
class RedisCache(BaseCache): # NON blocking - notify users Redis is throwing an exception logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + async def async_set_cache(self, key, value, **kwargs): + async with self.init_async_client() as redis_client: + ttl = kwargs.get("ttl", None) + print_verbose( + f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" + ) + try: + await redis_client.set(name=key, value=str(value), ex=ttl) + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + + def _get_cache_logic(self, cached_response: Any): + """ + Common 'get_cache_logic' across sync + async redis client implementations + """ + if cached_response is None: + return cached_response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_response.decode("utf-8") # Convert bytes to string + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except: + cached_response = ast.literal_eval(cached_response) + return cached_response + def get_cache(self, key, **kwargs): try: print_verbose(f"Get Redis Cache: key: {key}") @@ -95,26 +145,33 @@ class RedisCache(BaseCache): print_verbose( f"Got Redis Cache: key: {key}, cached_response {cached_response}" ) - if cached_response != None: - # cached_response is in `b{} convert it to ModelResponse - cached_response = cached_response.decode( - "utf-8" - ) # Convert bytes to string - try: - cached_response = json.loads( - cached_response - ) # Convert string to dictionary - except: - cached_response = ast.literal_eval(cached_response) - return cached_response + return self._get_cache_logic(cached_response=cached_response) except Exception as e: # NON blocking - notify users Redis is throwing an exception traceback.print_exc() logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + async def async_get_cache(self, key, **kwargs): + async with self.init_async_client() as redis_client: + try: + print_verbose(f"Get Redis Cache: key: {key}") + cached_response = await redis_client.get(key) + print_verbose( + f"Got Async Redis Cache: key: {key}, cached_response {cached_response}" + ) + response = self._get_cache_logic(cached_response=cached_response) + return response + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + traceback.print_exc() + logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + def flush_cache(self): self.redis_client.flushall() + async def disconnect(self): + pass + class S3Cache(BaseCache): def __init__( @@ -189,6 +246,9 @@ class S3Cache(BaseCache): # NON blocking - notify users S3 is throwing an exception print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}") + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + def get_cache(self, key, **kwargs): import boto3, botocore @@ -229,6 +289,9 @@ class S3Cache(BaseCache): traceback.print_exc() print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}") + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + def flush_cache(self): pass @@ -468,6 +531,45 @@ class Cache: } time.sleep(0.02) + def _get_cache_logic( + self, + cached_result: Optional[Any], + max_age: Optional[float], + ): + """ + Common get cache logic across sync + async implementations + """ + # Check if a timestamp was stored with the cached response + 
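`_get_cache_logic` centralizes the decode path that `get_cache` and `async_get_cache` now share: Redis hands back bytes, which are decoded and parsed as JSON, with `ast.literal_eval` as a fallback for `repr()`-style payloads. A self-contained sketch of that flow (the function name below is illustrative, not part of the patch):

```python
import ast
import json


def decode_cached_response(cached_response):
    """Sketch of the shared decode flow in RedisCache._get_cache_logic."""
    if cached_response is None:
        return None
    if isinstance(cached_response, bytes):
        cached_response = cached_response.decode("utf-8")  # Redis returns bytes
    try:
        return json.loads(cached_response)        # most payloads are JSON strings
    except (json.JSONDecodeError, TypeError):
        return ast.literal_eval(cached_response)  # fallback for repr()-style payloads


print(decode_cached_response(b'{"id": "chatcmpl-123"}'))  # JSON path
print(decode_cached_response(b"{'id': 'chatcmpl-123'}"))  # literal_eval path
```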
if ( + cached_result is not None + and isinstance(cached_result, dict) + and "timestamp" in cached_result + ): + timestamp = cached_result["timestamp"] + current_time = time.time() + + # Calculate age of the cached response + response_age = current_time - timestamp + + # Check if the cached response is older than the max-age + if max_age is not None and response_age > max_age: + return None # Cached response is too old + + # If the response is fresh, or there's no max-age requirement, return the cached response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_result.get("response") + try: + if isinstance(cached_response, dict): + pass + else: + cached_response = json.loads( + cached_response # type: ignore + ) # Convert string to dictionary + except: + cached_response = ast.literal_eval(cached_response) # type: ignore + return cached_response + return cached_result + def get_cache(self, *args, **kwargs): """ Retrieves the cached result for the given arguments. @@ -490,53 +592,40 @@ class Cache: "s-max-age", cache_control_args.get("s-maxage", float("inf")) ) cached_result = self.cache.get_cache(cache_key) - # Check if a timestamp was stored with the cached response - if ( - cached_result is not None - and isinstance(cached_result, dict) - and "timestamp" in cached_result - ): - timestamp = cached_result["timestamp"] - current_time = time.time() - - # Calculate age of the cached response - response_age = current_time - timestamp - - # Check if the cached response is older than the max-age - if max_age is not None and response_age > max_age: - print_verbose( - f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s" - ) - return None # Cached response is too old - - # If the response is fresh, or there's no max-age requirement, return the cached response - # cached_response is in `b{} convert it to ModelResponse - cached_response = cached_result.get("response") - try: - if isinstance(cached_response, dict): - pass - else: - cached_response = json.loads( - cached_response - ) # Convert string to dictionary - except: - cached_response = ast.literal_eval(cached_response) - return cached_response - return cached_result + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) except Exception as e: print_verbose(f"An exception occurred: {traceback.format_exc()}") return None - def add_cache(self, result, *args, **kwargs): + async def async_get_cache(self, *args, **kwargs): """ - Adds a result to the cache. + Async get cache implementation. 
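`Cache._get_cache_logic` also folds in the existing `s-maxage` freshness check: the stored `timestamp` is compared against the caller's max age before the cached response is returned. A small sketch of just that check, using a hypothetical `is_fresh` helper:

```python
import time


def is_fresh(cached_result: dict, max_age: float) -> bool:
    """Sketch of the timestamp / max-age check in Cache._get_cache_logic."""
    timestamp = cached_result["timestamp"]   # written by _add_cache_logic
    response_age = time.time() - timestamp   # seconds since the entry was stored
    return response_age <= max_age


entry = {"timestamp": time.time() - 120, "response": "..."}
print(is_fresh(entry, max_age=60))            # False -- entry is two minutes old
print(is_fresh(entry, max_age=float("inf")))  # True -- no s-maxage supplied
```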
- Args: - *args: args to litellm.completion() or embedding() - **kwargs: kwargs to litellm.completion() or embedding() + Used for embedding calls in async wrapper + """ + try: # never block execution + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(*args, **kwargs) + if cache_key is not None: + cache_control_args = kwargs.get("cache", {}) + max_age = cache_control_args.get( + "s-max-age", cache_control_args.get("s-maxage", float("inf")) + ) + cached_result = await self.cache.async_get_cache(cache_key) + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception as e: + print_verbose(f"An exception occurred: {traceback.format_exc()}") + return None - Returns: - None + def _add_cache_logic(self, result, *args, **kwargs): + """ + Common implementation across sync + async add_cache functions """ try: if "cache_key" in kwargs: @@ -555,17 +644,49 @@ class Cache: if k == "ttl": kwargs["ttl"] = v cached_data = {"timestamp": time.time(), "response": result} - self.cache.set_cache(cache_key, cached_data, **kwargs) + return cache_key, cached_data + else: + raise Exception("cache key is None") + except Exception as e: + raise e + + def add_cache(self, result, *args, **kwargs): + """ + Adds a result to the cache. + + Args: + *args: args to litellm.completion() or embedding() + **kwargs: kwargs to litellm.completion() or embedding() + + Returns: + None + """ + try: + cache_key, cached_data = self._add_cache_logic( + result=result, *args, **kwargs + ) + self.cache.set_cache(cache_key, cached_data, **kwargs) except Exception as e: print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") traceback.print_exc() pass - async def _async_add_cache(self, result, *args, **kwargs): - self.add_cache(result, *args, **kwargs) + async def async_add_cache(self, result, *args, **kwargs): + """ + Async implementation of add_cache + """ + try: + cache_key, cached_data = self._add_cache_logic( + result=result, *args, **kwargs + ) + await self.cache.async_set_cache(cache_key, cached_data, **kwargs) + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + traceback.print_exc() - async def _async_get_cache(self, *args, **kwargs): - return self.get_cache(*args, **kwargs) + async def disconnect(self): + if hasattr(self.cache, "disconnect"): + await self.cache.disconnect() def enable_cache( diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index c06ba7d32..9a1a01d66 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -346,7 +346,7 @@ def run_server( import gunicorn.app.base except: raise ImportError( - "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`" + "uvicorn, gunicorn needs to be imported. 
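The add path now follows the same split: `_add_cache_logic` builds `(cache_key, cached_data)` once, and thin `add_cache` / `async_add_cache` wrappers hand the payload to the sync or async backend. A toy illustration of the shape of that pattern (the `DemoCache` class below is purely illustrative, not litellm code):

```python
import asyncio
import time


class DemoCache:
    """Toy cache showing the sync/async split: shared _logic helper, thin wrappers."""

    def __init__(self):
        self.store = {}

    def _add_cache_logic(self, result, cache_key):
        # Shared across add_cache / async_add_cache: build the payload once.
        return cache_key, {"timestamp": time.time(), "response": result}

    def add_cache(self, result, cache_key):
        key, data = self._add_cache_logic(result, cache_key)
        self.store[key] = data                 # sync backend write

    async def async_add_cache(self, result, cache_key):
        key, data = self._add_cache_logic(result, cache_key)
        await asyncio.sleep(0)                 # stand-in for an awaited backend write
        self.store[key] = data


cache = DemoCache()
asyncio.run(cache.async_add_cache({"id": "chatcmpl-123"}, "demo-key"))
print(cache.store["demo-key"]["response"])
```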
Run - `pip install 'litellm[proxy]'`" ) if config is not None: @@ -427,36 +427,40 @@ def run_server( f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n" ) # noqa - # Gunicorn Application Class - class StandaloneApplication(gunicorn.app.base.BaseApplication): - def __init__(self, app, options=None): - self.options = options or {} # gunicorn options - self.application = app # FastAPI app - super().__init__() + uvicorn.run( + "litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers + ) - def load_config(self): - # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config - config = { - key: value - for key, value in self.options.items() - if key in self.cfg.settings and value is not None - } - for key, value in config.items(): - self.cfg.set(key.lower(), value) + # # Gunicorn Application Class + # class StandaloneApplication(gunicorn.app.base.BaseApplication): + # def __init__(self, app, options=None): + # self.options = options or {} # gunicorn options + # self.application = app # FastAPI app + # super().__init__() - def load(self): - # gunicorn app function - return self.application + # def load_config(self): + # # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config + # config = { + # key: value + # for key, value in self.options.items() + # if key in self.cfg.settings and value is not None + # } + # for key, value in config.items(): + # self.cfg.set(key.lower(), value) - gunicorn_options = { - "bind": f"{host}:{port}", - "workers": num_workers, # default is 1 - "worker_class": "uvicorn.workers.UvicornWorker", - "preload": True, # Add the preload flag - } - from litellm.proxy.proxy_server import app + # def load(self): + # # gunicorn app function + # return self.application - StandaloneApplication(app=app, options=gunicorn_options).run() # Run gunicorn + # gunicorn_options = { + # "bind": f"{host}:{port}", + # "workers": num_workers, # default is 1 + # "worker_class": "uvicorn.workers.UvicornWorker", + # "preload": True, # Add the preload flag + # } + # from litellm.proxy.proxy_server import app + + # StandaloneApplication(app=app, options=gunicorn_options).run() # Run gunicorn if __name__ == "__main__": diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e74314193..8fd62cde2 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7,6 +7,20 @@ import secrets, subprocess import hashlib, uuid import warnings import importlib +import warnings + + +def showwarning(message, category, filename, lineno, file=None, line=None): + traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n" + if file is not None: + file.write(traceback_info) + + +warnings.showwarning = showwarning +warnings.filterwarnings("default", category=UserWarning) + +# Your client code here + messages: list = [] sys.path.insert( @@ -2510,10 +2524,12 @@ async def get_routes(): @router.on_event("shutdown") async def shutdown_event(): global prisma_client, master_key, user_custom_auth - if prisma_client: + if prisma_client is not None: verbose_proxy_logger.debug("Disconnecting from Prisma") await prisma_client.disconnect() + if litellm.cache is not None: + await litellm.cache.disconnect() ## RESET CUSTOM VARIABLES ## cleanup_router_config_variables() diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 11d4fda15..3250a2621 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -266,8 +266,9 @@ async def 
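With the gunicorn `StandaloneApplication` commented out, the proxy now starts workers through `uvicorn.run` directly. Passing the app as an import string is what allows uvicorn to spawn multiple worker processes; a minimal standalone sketch, assuming the file is saved as `main.py` (the `main:app` target, port, and worker count are placeholders):

```python
import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
def health():
    return {"ok": True}


if __name__ == "__main__":
    # workers > 1 requires the "module:attribute" import-string form,
    # which is why the proxy passes "litellm.proxy.proxy_server:app".
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)
```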
test_embedding_caching_azure_individual_items(): """ Tests caching for individual items in an embedding list - Assert if the same embeddingresponse object is returned for the duplicate item in 2 embedding list calls - + - Cache an item + - call aembedding(..) with the item + 1 unique item + - compare to a 2nd aembedding (...) with 2 unique items ``` embedding_1 = ["hey how's it going", "I'm doing well"] embedding_val_1 = embedding(...) @@ -280,31 +281,98 @@ async def test_embedding_caching_azure_individual_items(): """ litellm.cache = Cache() common_msg = f"hey how's it going {uuid.uuid4()}" - embedding_1 = [common_msg, "I'm doing well"] - embedding_2 = [common_msg, "I'm fine"] + common_msg_2 = f"hey how's it going {uuid.uuid4()}" + embedding_2 = [ + common_msg, + f"I'm fine {uuid.uuid4()}", + common_msg, + common_msg, + common_msg, + ] * 20 + embedding_2 = [ + common_msg, + f"I'm fine {uuid.uuid4()}", + common_msg, + common_msg, + common_msg, + ] * 20 + embedding_3 = [ + common_msg_2, + common_msg_2, + common_msg_2, + common_msg_2, + f"I'm fine {uuid.uuid4()}", + ] * 20 # make sure azure doesn't return cached 'i'm fine' responses embedding_val_1 = await aembedding( model="azure/azure-embedding-model", input=embedding_1, caching=True ) + second_response_start_time = time.time() embedding_val_2 = await aembedding( model="azure/azure-embedding-model", input=embedding_2, caching=True ) - print(f"embedding_val_2: {embedding_val_2}") - if ( - embedding_val_2["data"][0]["embedding"] - != embedding_val_1["data"][0]["embedding"] - ): - print(f"embedding1: {embedding_val_1}") - print(f"embedding2: {embedding_val_2}") - pytest.fail("Error occurred: Embedding caching failed") - if ( - embedding_val_2["data"][1]["embedding"] - == embedding_val_1["data"][1]["embedding"] - ): - print(f"embedding1: {embedding_val_1}") - print(f"embedding2: {embedding_val_2}") - pytest.fail("Error occurred: Embedding caching failed") + if embedding_val_2 is not None: + second_response_end_time = time.time() + second_response_time = second_response_end_time - second_response_start_time + + third_response_start_time = time.time() + embedding_val_3 = await aembedding( + model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True} + ) + if embedding_val_3 is not None: + third_response_end_time = time.time() + third_response_time = third_response_end_time - third_response_start_time + + print(f"second_response_time: {second_response_time}") + print(f"third_response_time: {third_response_time}") + + assert ( + second_response_time < third_response_time - 0.5 + ) # make sure it's actually faster + raise Exception(f"it works {second_response_time} < {third_response_time}") + + +@pytest.mark.asyncio +async def test_redis_cache_basic(): + """ + Init redis client + - write to client + - read from client + """ + litellm.set_verbose = False + + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + {"role": "user", "content": f"write a one sentence poem about: {random_number}"} + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + ) + + cache_key = litellm.cache.get_cache_key( + model="gpt-3.5-turbo", + messages=messages, + ) + print(f"cache_key: {cache_key}") + litellm.cache.add_cache(result=response1, cache_key=cache_key) + print(f"cache key pre async get: 
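`test_redis_cache_basic` exercises the new read path end to end: write with the sync `add_cache`, read back with `async_get_cache`. The same roundtrip can be sketched without Redis credentials by falling back to the default in-memory cache (the key and payload below are made up):

```python
import asyncio

import litellm
from litellm.caching import Cache


async def main():
    litellm.cache = Cache()  # default in-process cache, so no Redis credentials needed
    cache_key = "demo-cache-key"  # made-up key for illustration

    # Write through the sync API, read back through the new async API.
    litellm.cache.add_cache(result={"id": "demo-response"}, cache_key=cache_key)
    stored = await litellm.cache.async_get_cache(cache_key=cache_key)
    print(stored)  # {'id': 'demo-response'}


asyncio.run(main())
```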
{cache_key}") + stored_val = await litellm.cache.async_get_cache( + model="gpt-3.5-turbo", + messages=messages, + ) + print(f"stored_val: {stored_val}") + assert stored_val["id"] == response1.id + raise Exception("it worked!") def test_redis_cache_completion(): diff --git a/litellm/utils.py b/litellm/utils.py index 8596b5d16..84a81649a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2214,8 +2214,13 @@ def client(original_function): ) # if caching is false, don't run this final_embedding_cached_response = None + if ( - (kwargs.get("caching", None) is None and litellm.cache is not None) + ( + kwargs.get("caching", None) is None + and kwargs.get("cache", None) is None + and litellm.cache is not None + ) or kwargs.get("caching", False) == True or ( kwargs.get("cache", None) is not None @@ -2234,12 +2239,13 @@ def client(original_function): kwargs["input"], list ): tasks = [] - embedding_kwargs = copy.deepcopy(kwargs) for idx, i in enumerate(kwargs["input"]): - embedding_kwargs["input"] = i + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) tasks.append( - litellm.cache._async_get_cache( - *args, **embedding_kwargs + litellm.cache.async_get_cache( + cache_key=preset_cache_key ) ) cached_result = await asyncio.gather(*tasks) @@ -2445,24 +2451,28 @@ def client(original_function): if isinstance(result, EmbeddingResponse) and isinstance( kwargs["input"], list ): - embedding_kwargs = copy.deepcopy(kwargs) for idx, i in enumerate(kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) embedding_response = result.data[idx] - embedding_kwargs["input"] = i asyncio.create_task( - litellm.cache._async_add_cache( - embedding_response, *args, **embedding_kwargs + litellm.cache.async_add_cache( + embedding_response, + *args, + cache_key=preset_cache_key, ) ) + # pass else: asyncio.create_task( - litellm.cache._async_add_cache( + litellm.cache.async_add_cache( result.json(), *args, **kwargs ) ) else: asyncio.create_task( - litellm.cache._async_add_cache(result, *args, **kwargs) + litellm.cache.async_add_cache(result, *args, **kwargs) ) # LOG SUCCESS - handle streaming success logging in the _next_ object print_verbose(
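In `utils.py`, the wrapper now builds one cache key per embedding input and fans the lookups out with `asyncio.gather`, instead of deep-copying kwargs for every item. A sketch of that per-item pattern, assuming a configured `litellm.cache` is passed in as `cache` (the function name is illustrative):

```python
import asyncio


async def lookup_embeddings_per_item(cache, model: str, inputs: list):
    """Per-item lookup pattern from the wrapper: one key per input, fetched concurrently."""
    tasks = []
    for item in inputs:
        # The wrapper builds the key via get_cache_key(*args, **{**kwargs, "input": item});
        # only model/input are passed here for brevity.
        preset_cache_key = cache.get_cache_key(model=model, input=item)
        tasks.append(cache.async_get_cache(cache_key=preset_cache_key))
    # Entries that come back as None still need a real embedding call.
    return await asyncio.gather(*tasks)
```

Because `asyncio.gather` preserves order, the results line up index-for-index with `inputs`, which is what lets the wrapper splice cached embeddings back into the final `EmbeddingResponse`.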