mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
(feat) RedisSemanticCache - async
This commit is contained in:
parent 05d64acdd3
commit 97fbfc07b4
1 changed file with 106 additions and 6 deletions
@@ -231,6 +231,7 @@ class RedisSemanticCache(BaseCache):
         password=None,
         redis_url=None,
         similarity_threshold=None,
+        use_async=False,
         **kwargs,
     ):
         from redisvl.index import SearchIndex
@@ -262,14 +263,19 @@ class RedisSemanticCache(BaseCache):
                 ],
             },
         }
-        self.index = SearchIndex.from_dict(schema)
         if redis_url is None:
             # if no url passed, check if host, port and password are passed, if not raise an Exception
             if host is None or port is None or password is None:
                 raise Exception(f"Redis host, port, and password must be provided")
             redis_url = "redis://:" + password + "@" + host + ":" + port
         print_verbose(f"redis semantic-cache redis_url: {redis_url}")
-        self.index.connect(redis_url=redis_url)
+        if use_async == False:
+            self.index = SearchIndex.from_dict(schema)
+            self.index.connect(redis_url=redis_url)
+        elif use_async == True:
+            schema["index"]["name"] = "litellm_semantic_cache_index_async"
+            self.index = SearchIndex.from_dict(schema)
+            self.index.connect(redis_url=redis_url, use_async=True)
         try:
             self.index.create(overwrite=False)  # don't overwrite existing index
         except Exception as e:
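To make the branch above concrete, here is a minimal usage sketch (not part of the commit; the import path and host/port/password values are assumptions). Note that port must be passed as a string, since it is concatenated directly into redis_url.

# Hypothetical sketch of the two constructor branches (placeholder values).
from litellm.caching import RedisSemanticCache  # assumed import path

# sync: builds the index from `schema` and connects over the constructed URL
sync_cache = RedisSemanticCache(
    host="localhost", port="6379", password="mypassword",
    similarity_threshold=0.8,
)

# async: same schema, but the index is renamed and connected with use_async=True
async_cache = RedisSemanticCache(
    host="localhost", port="6379", password="mypassword",
    similarity_threshold=0.8, use_async=True,
)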
@@ -327,10 +333,10 @@ class RedisSemanticCache(BaseCache):
         # Add more data
         keys = self.index.load(new_data)

-        pass
+        return

     def get_cache(self, key, **kwargs):
-        print_verbose(f"redis semantic-cache get_cache, kwargs: {kwargs}")
+        print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
         from redisvl.query import VectorQuery
         import numpy as np

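For context, the write path that ends at self.index.load(new_data) above follows the same embed-then-store shape that the new async method below spells out. Roughly, as a paraphrase (a standalone helper, not the committed code):

# Paraphrased shape of the synchronous write path (hypothetical helper).
import numpy as np
import litellm

def set_cache_sketch(index, value, messages):
    # concatenate message contents into a single prompt string
    prompt = "".join(m["content"] for m in messages)
    # embed the prompt; the embedding call itself should bypass the cache
    resp = litellm.embedding(model="text-embedding-ada-002", input=prompt)
    embedding = resp["data"][0]["embedding"]
    # redisvl stores vectors as raw float32 bytes
    embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
    index.load([
        {"response": str(value), "prompt": prompt, "litellm_embedding": embedding_bytes}
    ])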
@@ -360,6 +366,11 @@ class RedisSemanticCache(BaseCache):
         )

         results = self.index.query(query)
+        if results == None:
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                return None

         vector_distance = results[0]["vector_distance"]
         vector_distance = float(vector_distance)
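The new guard makes the sync lookup return None on an empty index instead of raising an IndexError on results[0]. Downstream (unchanged here), the hit/miss decision converts redisvl's vector distance into a similarity score; a worked example with hypothetical numbers:

# Hypothetical numbers: redisvl reports a cosine vector_distance; the cache
# treats 1 - distance as similarity and compares it against the threshold.
vector_distance = 0.12                        # distance to closest cached prompt
similarity = 1 - vector_distance              # 0.88
similarity_threshold = 0.8                    # example config value
is_hit = similarity > similarity_threshold    # True -> serve the cached response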
@@ -384,9 +395,93 @@ class RedisSemanticCache(BaseCache):
         pass

     async def async_set_cache(self, key, value, **kwargs):
-        pass
+        import numpy as np
+
+        print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+        # create an embedding for prompt
+
+        embedding_response = await litellm.aembedding(
+            model="text-embedding-ada-002",
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        # make the embedding a numpy array, convert to bytes
+        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+        value = str(value)
+        assert isinstance(value, str)
+
+        new_data = [
+            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+        ]
+
+        # Add more data
+        keys = await self.index.aload(new_data)
+        return

     async def async_get_cache(self, key, **kwargs):
+        print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
+        from redisvl.query import VectorQuery
+        import numpy as np
+
+        # query
+
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # convert to embedding
+        embedding_response = await litellm.aembedding(
+            model="text-embedding-ada-002",
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        query = VectorQuery(
+            vector=embedding,
+            vector_field_name="litellm_embedding",
+            return_fields=["response", "prompt", "vector_distance"],
+        )
+        results = await self.index.aquery(query)
+        if results == None:
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                return None
+
+        vector_distance = results[0]["vector_distance"]
+        vector_distance = float(vector_distance)
+        similarity = 1 - vector_distance
+        cached_prompt = results[0]["prompt"]
+
+        # check similarity, if more than self.similarity_threshold, return results
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+        if similarity > self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
         pass

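A minimal sketch of exercising the new async path end to end (hypothetical prompts; assumes a reachable Redis instance and credentials for the embedding model). Note that both methods ignore the key argument: lookups go through embedding similarity over the concatenated message contents.

# Hypothetical driver for the new async methods (not part of the commit).
import asyncio

async def demo(cache):  # cache: a RedisSemanticCache built with use_async=True
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    await cache.async_set_cache("unused-key", "Paris", messages=messages)

    # a semantically similar prompt should land within the threshold
    similar = [{"role": "user", "content": "France's capital city?"}]
    print(await cache.async_get_cache("unused-key", messages=similar))

# asyncio.run(demo(async_cache))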
@@ -612,6 +707,7 @@ class Cache:
         s3_aws_secret_access_key: Optional[str] = None,
         s3_aws_session_token: Optional[str] = None,
         s3_config: Optional[Any] = None,
+        redis_semantic_cache_use_async=False,
         **kwargs,
     ):
         """
@@ -641,6 +737,7 @@ class Cache:
                 port,
                 password,
                 similarity_threshold=similarity_threshold,
+                use_async=redis_semantic_cache_use_async,
                 **kwargs,
             )
         elif type == "local":
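Putting the two Cache changes together, a caller could opt in like this (a sketch with placeholder credentials; the "redis-semantic" type is the cache type litellm's docs use for this branch):

# Hedged example of enabling the async semantic cache.
import litellm
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis-semantic",                 # selects the RedisSemanticCache branch
    host="localhost",
    port="6379",
    password="mypassword",
    similarity_threshold=0.8,
    redis_semantic_cache_use_async=True,   # the flag added by this commit
)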
@@ -847,6 +944,7 @@ class Cache:
         Used for embedding calls in async wrapper
         """
         try:  # never block execution
+            messages = kwargs.get("messages", [])
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
             else:
@@ -856,7 +954,9 @@ class Cache:
             max_age = cache_control_args.get(
                 "s-max-age", cache_control_args.get("s-maxage", float("inf"))
             )
-            cached_result = await self.cache.async_get_cache(cache_key)
+            cached_result = await self.cache.async_get_cache(
+                cache_key, messages=messages
+            )
             return self._get_cache_logic(
                 cached_result=cached_result, max_age=max_age
             )
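These last two hunks exist because the semantic cache cannot work from the hash-based cache_key alone: it needs the raw messages to embed. In expanded form, the call the wrapper now makes is roughly:

# Expanded form of the new call (comments are explanatory, not in the commit):
cached_result = await self.cache.async_get_cache(
    cache_key,                            # exact-match backends key off this
    messages=kwargs.get("messages", []),  # the semantic cache embeds this text instead
)

For backends whose async_get_cache accepts **kwargs, the extra messages argument should simply be ignored, so the wrapper change is backward compatible.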