working with embeddings

This commit is contained in:
ishaan-jaff 2023-08-28 14:07:51 -07:00
parent 53018e4b23
commit 71ff0c69ab
3 changed files with 56 additions and 18 deletions

View file

@ -32,11 +32,15 @@ class InMemoryCache():
self.cache_dict = {}
def set_cache(self, key, value):
#print("in set cache for inmem")
self.cache_dict[key] = value
def get_cache(self, key):
#print("in get cache for inmem")
if key in self.cache_dict:
#print("got a cache hit")
return self.cache_dict[key]
#print("got a cache miss")
return None
class Cache():
@ -46,27 +50,35 @@ class Cache():
if type == "local":
self.cache = InMemoryCache()
def check_cache(self, *args, **kwargs):
def get_cache_key(self, *args, **kwargs):
prompt = get_prompt(*args, **kwargs)
if prompt is not None:
cache_key = prompt
if "model" in kwargs:
cache_key += kwargs["model"]
elif "input" in kwargs:
cache_key = " ".join(kwargs["input"])
if "model" in kwargs:
cache_key += kwargs["model"]
else:
return None
return cache_key
def get_cache(self, *args, **kwargs):
try: # never block execution
prompt = get_prompt(*args, **kwargs)
if prompt != None: # check if messages / prompt exists
if "model" in kwargs: # default to caching with `model + prompt` as key
cache_key = prompt + kwargs["model"]
return self.cache.get_cache(cache_key)
else:
return self.cache.get_cache(prompt)
cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None:
return self.cache.get_cache(cache_key)
except:
return None
def add_cache(self, result, *args, **kwargs):
try:
prompt = get_prompt(*args, **kwargs)
if "model" in kwargs: # default to caching with `model + prompt` as key
cache_key = prompt + kwargs["model"]
cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None:
self.cache.set_cache(cache_key, result)
else:
self.cache.set_cache(prompt, result)
except:
pass
@ -77,5 +89,3 @@ class Cache():