From 97fbfc07b4fe13f43668eaf89f62d5d95afb28e3 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 6 Feb 2024 08:13:12 -0800 Subject: [PATCH] (feat) RedisSemanticCache - async --- litellm/caching.py | 112 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 106 insertions(+), 6 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 877f935fab..ad37f2077c 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -231,6 +231,7 @@ class RedisSemanticCache(BaseCache): password=None, redis_url=None, similarity_threshold=None, + use_async=False, **kwargs, ): from redisvl.index import SearchIndex @@ -262,14 +263,19 @@ class RedisSemanticCache(BaseCache): ], }, } - self.index = SearchIndex.from_dict(schema) if redis_url is None: # if no url passed, check if host, port and password are passed, if not raise an Exception if host is None or port is None or password is None: raise Exception(f"Redis host, port, and password must be provided") redis_url = "redis://:" + password + "@" + host + ":" + port print_verbose(f"redis semantic-cache redis_url: {redis_url}") - self.index.connect(redis_url=redis_url) + if use_async == False: + self.index = SearchIndex.from_dict(schema) + self.index.connect(redis_url=redis_url) + elif use_async == True: + schema["index"]["name"] = "litellm_semantic_cache_index_async" + self.index = SearchIndex.from_dict(schema) + self.index.connect(redis_url=redis_url, use_async=True) try: self.index.create(overwrite=False) # don't overwrite existing index except Exception as e: @@ -327,10 +333,10 @@ class RedisSemanticCache(BaseCache): # Add more data keys = self.index.load(new_data) - pass + return def get_cache(self, key, **kwargs): - print_verbose(f"redis semantic-cache get_cache, kwargs: {kwargs}") + print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}") from redisvl.query import VectorQuery import numpy as np @@ -360,6 +366,11 @@ class RedisSemanticCache(BaseCache): ) results = self.index.query(query) + if results == None: + return None + if isinstance(results, list): + if len(results) == 0: + return None vector_distance = results[0]["vector_distance"] vector_distance = float(vector_distance) @@ -384,9 +395,93 @@ class RedisSemanticCache(BaseCache): pass async def async_set_cache(self, key, value, **kwargs): - pass + import numpy as np + + print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}") + + # get the prompt + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + # create an embedding for prompt + + embedding_response = await litellm.aembedding( + model="text-embedding-ada-002", + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + # make the embedding a numpy array, convert to bytes + embedding_bytes = np.array(embedding, dtype=np.float32).tobytes() + value = str(value) + assert isinstance(value, str) + + new_data = [ + {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes} + ] + + # Add more data + keys = await self.index.aload(new_data) + return async def async_get_cache(self, key, **kwargs): + print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}") + from redisvl.query import VectorQuery + import numpy as np + + # query + + # get the messages + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + # convert to embedding + embedding_response = await litellm.aembedding( + model="text-embedding-ada-002", + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + query = VectorQuery( + vector=embedding, + vector_field_name="litellm_embedding", + return_fields=["response", "prompt", "vector_distance"], + ) + results = await self.index.aquery(query) + if results == None: + return None + if isinstance(results, list): + if len(results) == 0: + return None + + vector_distance = results[0]["vector_distance"] + vector_distance = float(vector_distance) + similarity = 1 - vector_distance + cached_prompt = results[0]["prompt"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + if similarity > self.similarity_threshold: + # cache hit ! + cached_value = results[0]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! + return None pass @@ -612,6 +707,7 @@ class Cache: s3_aws_secret_access_key: Optional[str] = None, s3_aws_session_token: Optional[str] = None, s3_config: Optional[Any] = None, + redis_semantic_cache_use_async=False, **kwargs, ): """ @@ -641,6 +737,7 @@ class Cache: port, password, similarity_threshold=similarity_threshold, + use_async=redis_semantic_cache_use_async, **kwargs, ) elif type == "local": @@ -847,6 +944,7 @@ class Cache: Used for embedding calls in async wrapper """ try: # never block execution + messages = kwargs.get("messages", []) if "cache_key" in kwargs: cache_key = kwargs["cache_key"] else: @@ -856,7 +954,9 @@ class Cache: max_age = cache_control_args.get( "s-max-age", cache_control_args.get("s-maxage", float("inf")) ) - cached_result = await self.cache.async_get_cache(cache_key) + cached_result = await self.cache.async_get_cache( + cache_key, messages=messages + ) return self._get_cache_logic( cached_result=cached_result, max_age=max_age )