From d4a799a3ca97ce84f8491a8b372eaea7651292e5 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 5 Feb 2024 12:28:21 -0800 Subject: [PATCH] (feat )add semantic cache --- litellm/caching.py | 102 +++++++++++++++++++++++++++++++++- litellm/tests/test_caching.py | 25 +++++++++ 2 files changed, 124 insertions(+), 3 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index d0721fe9a..e1ef95dc3 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -83,7 +83,6 @@ class InMemoryCache(BaseCache): self.cache_dict.clear() self.ttl_dict.clear() - async def disconnect(self): pass @@ -217,7 +216,6 @@ class RedisCache(BaseCache): def flush_cache(self): self.redis_client.flushall() - async def disconnect(self): pass @@ -225,6 +223,102 @@ class RedisCache(BaseCache): self.redis_client.delete(key) +class RedisSemanticCache(RedisCache): + def __init__(self, host, port, password, **kwargs): + super().__init__() + + # from redis.commands.search.field import TagField, TextField, NumericField, VectorField + # from redis.commands.search.indexDefinition import IndexDefinition, IndexType + # from redis.commands.search.query import Query + + # INDEX_NAME = 'idx:litellm_completion_response_vss' + # DOC_PREFIX = 'bikes:' + + # try: + # # check to see if index exists + # client.ft(INDEX_NAME).info() + # print('Index already exists!') + # except: + # # schema + # schema = ( + # TextField('$.model', no_stem=True, as_name='model'), + # TextField('$.brand', no_stem=True, as_name='brand'), + # NumericField('$.price', as_name='price'), + # TagField('$.type', as_name='type'), + # TextField('$.description', as_name='description'), + # VectorField('$.description_embeddings', + # 'FLAT', { + # 'TYPE': 'FLOAT32', + # 'DIM': VECTOR_DIMENSION, + # 'DISTANCE_METRIC': 'COSINE', + # }, as_name='vector' + # ), + # ) + + # # index Definition + # definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.JSON) + + # # create Index + # client.ft(INDEX_NAME).create_index(fields=schema, definition=definition) + + def set_cache(self, key, value, **kwargs): + ttl = kwargs.get("ttl", None) + print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}") + try: + # get text response + # print("in redis semantic cache: value: ", value) + llm_response = value["response"] + + # if llm_response is a string, convert it to a dictionary + if isinstance(llm_response, str): + llm_response = json.loads(llm_response) + + # print("converted llm_response: ", llm_response) + response = llm_response["choices"][0]["message"]["content"] + + # create embedding response + + embedding_response = litellm.embedding( + model="text-embedding-ada-002", + input=response, + cache={"no-store": True}, + ) + + raw_embedding = embedding_response["data"][0]["embedding"] + raw_embedding_dimension = len(raw_embedding) + + # print("embedding: ", raw_embedding) + key = "litellm-semantic:" + key + self.redis_client.json().set( + name=key, + path="$", + obj=json.dumps( + { + "response": response, + "embedding": raw_embedding, + "dimension": raw_embedding_dimension, + } + ), + ) + + stored_redis_value = self.redis_client.json().get(name=key) + + # print("Stored Redis Value: ", stored_redis_value) + + except Exception as e: + # print("Error occurred: ", e) + # NON blocking - notify users Redis is throwing an exception + logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + + def get_cache(self, key, **kwargs): + pass + + async def async_set_cache(self, key, value, **kwargs): + pass + + async def async_get_cache(self, key, **kwargs): + pass + class S3Cache(BaseCache): def __init__( @@ -429,7 +523,7 @@ class DualCache(BaseCache): class Cache: def __init__( self, - type: Optional[Literal["local", "redis", "s3"]] = "local", + type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local", host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, @@ -468,6 +562,8 @@ class Cache: """ if type == "redis": self.cache: BaseCache = RedisCache(host, port, password, **kwargs) + elif type == "redis-semantic": + self.cache = RedisSemanticCache(host, port, password, **kwargs) elif type == "local": self.cache = InMemoryCache() elif type == "s3": diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 468ab6f80..32904ab78 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -987,3 +987,28 @@ def test_cache_context_managers(): # test_cache_context_managers() + + +def test_redis_semantic_cache_completion(): + litellm.set_verbose = False + + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + {"role": "user", "content": f"write a one sentence poem about: {random_number}"} + ] + litellm.cache = Cache( + type="redis-semantic", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) + print("test2 for Redis Caching - non streaming") + response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=20) + # response2 = completion( + # model="gpt-3.5-turbo", messages=messages,max_tokens=20 + # ) + + +# test_redis_cache_completion()