diff --git a/litellm/caching.py b/litellm/caching.py
index c7ed6a258..d24a301e2 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -32,11 +32,15 @@ class InMemoryCache():
         self.cache_dict = {}
 
     def set_cache(self, key, value):
+        #print("in set cache for inmem")
         self.cache_dict[key] = value
 
     def get_cache(self, key):
+        #print("in get cache for inmem")
         if key in self.cache_dict:
+            #print("got a cache hit")
            return self.cache_dict[key]
+        #print("got a cache miss")
         return None
 
 class Cache():
@@ -46,27 +50,35 @@ class Cache():
         if type == "local":
             self.cache = InMemoryCache()
 
-    def check_cache(self, *args, **kwargs):
+    def get_cache_key(self, *args, **kwargs):
+        prompt = get_prompt(*args, **kwargs)
+        if prompt is not None:
+            cache_key = prompt
+            if "model" in kwargs:
+                cache_key += kwargs["model"]
+        elif "input" in kwargs:
+            cache_key = " ".join(kwargs["input"])
+            if "model" in kwargs:
+                cache_key += kwargs["model"]
+        else:
+            return None
+        return cache_key
+
+    def get_cache(self, *args, **kwargs):
         try: # never block execution
-            prompt = get_prompt(*args, **kwargs)
-            if prompt != None: # check if messages / prompt exists
-                if "model" in kwargs: # default to caching with `model + prompt` as key
-                    cache_key = prompt + kwargs["model"]
-                    return self.cache.get_cache(cache_key)
-                else:
-                    return self.cache.get_cache(prompt)
+            cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
+                return self.cache.get_cache(cache_key)
         except:
             return None
 
     def add_cache(self, result, *args, **kwargs):
         try:
-            prompt = get_prompt(*args, **kwargs)
-            if "model" in kwargs: # default to caching with `model + prompt` as key
-                cache_key = prompt + kwargs["model"]
+            cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
                 self.cache.set_cache(cache_key, result)
-            else:
-                self.cache.set_cache(prompt, result)
         except:
+            pass
 
 
 
@@ -77,5 +89,3 @@ class Cache():
 
 
 
-
-
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index c4d89bc98..1845f3b47 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -12,6 +12,7 @@ import pytest
 import litellm
 from litellm import embedding, completion
 from litellm.caching import Cache
+# litellm.set_verbose=True
 
 messages = [{"role": "user", "content": "who is ishaan Github? "}]
 # comment
@@ -83,7 +84,7 @@ def test_gpt_cache():
 
 ####### Updated Caching as of Aug 28, 2023 ###################
 messages = [{"role": "user", "content": "who is ishaan 5222"}]
-def test_caching():
+def test_caching_v2():
     try:
         litellm.cache = Cache()
         response1 = completion(model="gpt-3.5-turbo", messages=messages)
@@ -102,7 +103,7 @@
 
 # test_caching()
 
-def test_caching_with_models():
+def test_caching_with_models_v2():
     messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
     litellm.cache = Cache()
     print("test2 for caching")
@@ -123,6 +124,33 @@
         print(f"response2: {response2}")
         pytest.fail(f"Error occurred:")
 
+
+embedding_large_text = """
+small text
+""" * 5
+
 # test_caching_with_models()
+def test_embedding_caching():
+    import time
+    litellm.cache = Cache()
+    text_to_embed = [embedding_large_text]
+    start_time = time.time()
+    embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed)
+    end_time = time.time()
+    print(f"Embedding 1 response time: {end_time - start_time} seconds")
+
+    time.sleep(1)
+    start_time = time.time()
+    embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed)
+    end_time = time.time()
+    print(f"Embedding 2 response time: {end_time - start_time} seconds")
+
+    litellm.cache = None
+    if embedding2 != embedding1:
+        print(f"embedding1: {embedding1}")
+        print(f"embedding2: {embedding2}")
+        pytest.fail("Error occurred: Embedding caching failed")
+
+# test_embedding_caching()
 
diff --git a/litellm/utils.py b/litellm/utils.py
index 7746a82d8..8e9b1dd0e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -403,7 +403,7 @@ def client(original_function):
         start_time = datetime.datetime.now()
         # [OPTIONAL] CHECK CACHE
         if (litellm.caching or litellm.caching_with_models or litellm.cache != None) and (
-            cached_result := litellm.cache.check_cache(*args, **kwargs)
+            cached_result := litellm.cache.get_cache(*args, **kwargs)
         ) is not None:
             result = cached_result
             return result