working with embeddings

ishaan-jaff 2023-08-28 14:07:51 -07:00
parent 53018e4b23
commit 71ff0c69ab
3 changed files with 56 additions and 18 deletions


@@ -32,11 +32,15 @@ class InMemoryCache():
         self.cache_dict = {}
 
     def set_cache(self, key, value):
+        #print("in set cache for inmem")
         self.cache_dict[key] = value
 
     def get_cache(self, key):
+        #print("in get cache for inmem")
         if key in self.cache_dict:
+            #print("got a cache hit")
             return self.cache_dict[key]
+        #print("got a cache miss")
         return None
 
 class Cache():
@@ -46,27 +50,35 @@ class Cache():
         if type == "local":
             self.cache = InMemoryCache()
 
-    def check_cache(self, *args, **kwargs):
+    def get_cache_key(self, *args, **kwargs):
+        prompt = get_prompt(*args, **kwargs)
+        if prompt is not None:
+            cache_key = prompt
+            if "model" in kwargs:
+                cache_key += kwargs["model"]
+        elif "input" in kwargs:
+            cache_key = " ".join(kwargs["input"])
+            if "model" in kwargs:
+                cache_key += kwargs["model"]
+        else:
+            return None
+        return cache_key
+
+    def get_cache(self, *args, **kwargs):
         try: # never block execution
-            prompt = get_prompt(*args, **kwargs)
-            if prompt != None: # check if messages / prompt exists
-                if "model" in kwargs: # default to caching with `model + prompt` as key
-                    cache_key = prompt + kwargs["model"]
-                    return self.cache.get_cache(cache_key)
-                else:
-                    return self.cache.get_cache(prompt)
+            cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
+                return self.cache.get_cache(cache_key)
         except:
             return None
 
     def add_cache(self, result, *args, **kwargs):
         try:
-            prompt = get_prompt(*args, **kwargs)
-            if "model" in kwargs: # default to caching with `model + prompt` as key
-                cache_key = prompt + kwargs["model"]
+            cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
                 self.cache.set_cache(cache_key, result)
-            else:
-                self.cache.set_cache(prompt, result)
         except:
             pass
 
@@ -77,5 +89,3 @@ class Cache():
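The net effect in litellm.caching: key construction now lives in a single get_cache_key helper that covers both completion-style calls (keyed on the prompt/messages, optionally suffixed with the model) and embedding-style calls (keyed on the joined input list, optionally suffixed with the model). A minimal standalone sketch of that derivation follows; extract_prompt is a hypothetical stand-in for litellm's get_prompt helper, which this diff does not show.

# Standalone sketch, not litellm code: mirrors the key derivation added above.
def extract_prompt(*args, **kwargs):
    # hypothetical stand-in for litellm's get_prompt helper
    messages = kwargs.get("messages")
    if messages:
        # flatten chat messages into a single string
        return " ".join(m["content"] for m in messages)
    return kwargs.get("prompt")

def derive_cache_key(*args, **kwargs):
    prompt = extract_prompt(*args, **kwargs)
    if prompt is not None:                       # completion / chat call
        cache_key = prompt
        if "model" in kwargs:
            cache_key += kwargs["model"]
    elif "input" in kwargs:                      # embedding call
        cache_key = " ".join(kwargs["input"])
        if "model" in kwargs:
            cache_key += kwargs["model"]
    else:
        return None
    return cache_key

# e.g. derive_cache_key(model="text-embedding-ada-002", input=["some text"])
# returns "some texttext-embedding-ada-002"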


@@ -12,6 +12,7 @@ import pytest
 import litellm
 from litellm import embedding, completion
 from litellm.caching import Cache
+# litellm.set_verbose=True
 
 messages = [{"role": "user", "content": "who is ishaan Github? "}]
 # comment
@@ -83,7 +84,7 @@ def test_gpt_cache():
 
 ####### Updated Caching as of Aug 28, 2023 ###################
 messages = [{"role": "user", "content": "who is ishaan 5222"}]
-def test_caching():
+def test_caching_v2():
     try:
         litellm.cache = Cache()
         response1 = completion(model="gpt-3.5-turbo", messages=messages)
@@ -102,7 +103,7 @@ def test_caching():
 
 # test_caching()
 
-def test_caching_with_models():
+def test_caching_with_models_v2():
     messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
     litellm.cache = Cache()
     print("test2 for caching")
@@ -123,6 +124,33 @@ def test_caching_with_models():
         print(f"response2: {response2}")
         pytest.fail(f"Error occurred:")
 
+embedding_large_text = """
+small text
+""" * 5
+
 # test_caching_with_models()
+def test_embedding_caching():
+    import time
+    litellm.cache = Cache()
+    text_to_embed = [embedding_large_text]
+
+    start_time = time.time()
+    embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed)
+    end_time = time.time()
+    print(f"Embedding 1 response time: {end_time - start_time} seconds")
+
+    time.sleep(1)
+    start_time = time.time()
+    embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed)
+    end_time = time.time()
+    print(f"Embedding 2 response time: {end_time - start_time} seconds")
+
+    litellm.cache = None
+    if embedding2 != embedding1:
+        print(f"embedding1: {embedding1}")
+        print(f"embedding2: {embedding2}")
+        pytest.fail("Error occurred: Embedding caching failed")
+
+# test_embedding_caching()
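Outside of pytest, the same flow can be exercised directly. A hedged usage sketch, assuming a configured OpenAI API key and reusing the public entry points imported at the top of the test file:

import time
import litellm
from litellm import embedding
from litellm.caching import Cache

litellm.cache = Cache()                      # enable the in-memory cache

text = ["hello from the embedding cache"]
t0 = time.time()
first = embedding(model="text-embedding-ada-002", input=text)   # network call
t1 = time.time()
second = embedding(model="text-embedding-ada-002", input=text)  # served from cache
t2 = time.time()
print(f"uncached: {t1 - t0:.2f}s, cached: {t2 - t1:.2f}s")

litellm.cache = None                         # turn caching back off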


@@ -403,7 +403,7 @@ def client(original_function):
         start_time = datetime.datetime.now()
         # [OPTIONAL] CHECK CACHE
         if (litellm.caching or litellm.caching_with_models or litellm.cache != None) and (
-            cached_result := litellm.cache.check_cache(*args, **kwargs)
+            cached_result := litellm.cache.get_cache(*args, **kwargs)
         ) is not None:
             result = cached_result
             return result
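With this rename, the client wrapper short-circuits through Cache.get_cache before dispatching the actual call. A simplified sketch of that check-then-call pattern is below; cached_client is illustrative only, and the real client decorator does more than this (logging, exception handling, timing).

# Simplified sketch of the decorator-level pattern used above; not the litellm wrapper.
import functools

def cached_client(cache, original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        # short-circuit: return the cached result when get_cache finds a key hit
        if cache is not None and (cached := cache.get_cache(*args, **kwargs)) is not None:
            return cached
        result = original_function(*args, **kwargs)
        if cache is not None:
            cache.add_cache(result, *args, **kwargs)
        return result
    return wrapper

Used as wrapped = cached_client(litellm.cache, some_call_fn): the wrapped call returns the cached result when the key matches and falls through to the provider otherwise.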