(feat) RedisSemanticCache - async

This commit is contained in:
ishaan-jaff 2024-02-06 08:13:12 -08:00
parent ccc94128d3
commit 76def20ffe

View file

@ -231,6 +231,7 @@ class RedisSemanticCache(BaseCache):
password=None, password=None,
redis_url=None, redis_url=None,
similarity_threshold=None, similarity_threshold=None,
use_async=False,
**kwargs, **kwargs,
): ):
from redisvl.index import SearchIndex from redisvl.index import SearchIndex
@ -262,14 +263,19 @@ class RedisSemanticCache(BaseCache):
], ],
}, },
} }
self.index = SearchIndex.from_dict(schema)
if redis_url is None: if redis_url is None:
# if no url passed, check if host, port and password are passed, if not raise an Exception # if no url passed, check if host, port and password are passed, if not raise an Exception
if host is None or port is None or password is None: if host is None or port is None or password is None:
raise Exception(f"Redis host, port, and password must be provided") raise Exception(f"Redis host, port, and password must be provided")
redis_url = "redis://:" + password + "@" + host + ":" + port redis_url = "redis://:" + password + "@" + host + ":" + port
print_verbose(f"redis semantic-cache redis_url: {redis_url}") print_verbose(f"redis semantic-cache redis_url: {redis_url}")
if use_async == False:
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url) self.index.connect(redis_url=redis_url)
elif use_async == True:
schema["index"]["name"] = "litellm_semantic_cache_index_async"
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url, use_async=True)
try: try:
self.index.create(overwrite=False) # don't overwrite existing index self.index.create(overwrite=False) # don't overwrite existing index
except Exception as e: except Exception as e:
@ -327,10 +333,10 @@ class RedisSemanticCache(BaseCache):
# Add more data # Add more data
keys = self.index.load(new_data) keys = self.index.load(new_data)
pass return
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
print_verbose(f"redis semantic-cache get_cache, kwargs: {kwargs}") print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
from redisvl.query import VectorQuery from redisvl.query import VectorQuery
import numpy as np import numpy as np
@ -360,6 +366,11 @@ class RedisSemanticCache(BaseCache):
) )
results = self.index.query(query) results = self.index.query(query)
if results == None:
return None
if isinstance(results, list):
if len(results) == 0:
return None
vector_distance = results[0]["vector_distance"] vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance) vector_distance = float(vector_distance)
@ -384,9 +395,93 @@ class RedisSemanticCache(BaseCache):
pass pass
async def async_set_cache(self, key, value, **kwargs): async def async_set_cache(self, key, value, **kwargs):
pass import numpy as np
print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
embedding_response = await litellm.aembedding(
model="text-embedding-ada-002",
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
# make the embedding a numpy array, convert to bytes
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
value = str(value)
assert isinstance(value, str)
new_data = [
{"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
]
# Add more data
keys = await self.index.aload(new_data)
return
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
from redisvl.query import VectorQuery
import numpy as np
# query
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# convert to embedding
embedding_response = await litellm.aembedding(
model="text-embedding-ada-002",
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
query = VectorQuery(
vector=embedding,
vector_field_name="litellm_embedding",
return_fields=["response", "prompt", "vector_distance"],
)
results = await self.index.aquery(query)
if results == None:
return None
if isinstance(results, list):
if len(results) == 0:
return None
vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance)
similarity = 1 - vector_distance
cached_prompt = results[0]["prompt"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
if similarity > self.similarity_threshold:
# cache hit !
cached_value = results[0]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
pass pass
@ -612,6 +707,7 @@ class Cache:
s3_aws_secret_access_key: Optional[str] = None, s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None, s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None, s3_config: Optional[Any] = None,
redis_semantic_cache_use_async=False,
**kwargs, **kwargs,
): ):
""" """
@ -641,6 +737,7 @@ class Cache:
port, port,
password, password,
similarity_threshold=similarity_threshold, similarity_threshold=similarity_threshold,
use_async=redis_semantic_cache_use_async,
**kwargs, **kwargs,
) )
elif type == "local": elif type == "local":
@ -847,6 +944,7 @@ class Cache:
Used for embedding calls in async wrapper Used for embedding calls in async wrapper
""" """
try: # never block execution try: # never block execution
messages = kwargs.get("messages", [])
if "cache_key" in kwargs: if "cache_key" in kwargs:
cache_key = kwargs["cache_key"] cache_key = kwargs["cache_key"]
else: else:
@ -856,7 +954,9 @@ class Cache:
max_age = cache_control_args.get( max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf")) "s-max-age", cache_control_args.get("s-maxage", float("inf"))
) )
cached_result = await self.cache.async_get_cache(cache_key) cached_result = await self.cache.async_get_cache(
cache_key, messages=messages
)
return self._get_cache_logic( return self._get_cache_logic(
cached_result=cached_result, max_age=max_age cached_result=cached_result, max_age=max_age
) )