(feat) add semantic cache

This commit is contained in:
ishaan-jaff 2024-02-05 12:28:21 -08:00
parent 646764f1d4
commit d4a799a3ca
2 changed files with 124 additions and 3 deletions

View file

@@ -83,7 +83,6 @@ class InMemoryCache(BaseCache):
self.cache_dict.clear()
self.ttl_dict.clear()
async def disconnect(self):
pass
@@ -217,7 +216,6 @@ class RedisCache(BaseCache):
def flush_cache(self):
self.redis_client.flushall()
async def disconnect(self):
pass
@@ -225,6 +223,102 @@ class RedisCache(BaseCache):
self.redis_client.delete(key)
class RedisSemanticCache(RedisCache):
def __init__(self, host, port, password, **kwargs):
super().__init__(host, port, password, **kwargs)
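# NOTE: the commented-out index setup below is placeholder code adapted from
# the Redis vector-search docs example; the 'bikes:' DOC_PREFIX and the
# brand/price/type fields still reference that sample dataset, and `client`
# is undefined in this scope.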
# from redis.commands.search.field import TagField, TextField, NumericField, VectorField
# from redis.commands.search.indexDefinition import IndexDefinition, IndexType
# from redis.commands.search.query import Query
# INDEX_NAME = 'idx:litellm_completion_response_vss'
# DOC_PREFIX = 'bikes:'
# try:
# # check to see if index exists
# client.ft(INDEX_NAME).info()
# print('Index already exists!')
# except:
# # schema
# schema = (
# TextField('$.model', no_stem=True, as_name='model'),
# TextField('$.brand', no_stem=True, as_name='brand'),
# NumericField('$.price', as_name='price'),
# TagField('$.type', as_name='type'),
# TextField('$.description', as_name='description'),
# VectorField('$.description_embeddings',
# 'FLAT', {
# 'TYPE': 'FLOAT32',
# 'DIM': VECTOR_DIMENSION,
# 'DISTANCE_METRIC': 'COSINE',
# }, as_name='vector'
# ),
# )
# # index Definition
# definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.JSON)
# # create Index
# client.ft(INDEX_NAME).create_index(fields=schema, definition=definition)
def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None)
print_verbose(f"Set Redis Semantic Cache: key: {key}\nValue {value}\nttl={ttl}")
try:
# get text response
# print("in redis semantic cache: value: ", value)
llm_response = value["response"]
# if llm_response is a string, convert it to a dictionary
if isinstance(llm_response, str):
llm_response = json.loads(llm_response)
# print("converted llm_response: ", llm_response)
response = llm_response["choices"][0]["message"]["content"]
# create embedding for the response text; "no-store" keeps this
# embedding call's own result from being written back to the cache
embedding_response = litellm.embedding(
model="text-embedding-ada-002",
input=response,
cache={"no-store": True},
)
raw_embedding = embedding_response["data"][0]["embedding"]
raw_embedding_dimension = len(raw_embedding)
# print("embedding: ", raw_embedding)
key = "litellm-semantic:" + key
# store as a real JSON object; wrapping in json.dumps would double-encode
# the value as a string and break JSONPath / vector index lookups
self.redis_client.json().set(
name=key,
path="$",
obj={
"response": response,
"embedding": raw_embedding,
"dimension": raw_embedding_dimension,
},
)
stored_redis_value = self.redis_client.json().get(name=key)
# print("Stored Redis Value: ", stored_redis_value)
except Exception as e:
# non-blocking - notify users Redis is throwing an exception
logging.debug(f"LiteLLM Caching: set() - Got exception from REDIS: {e}")
def get_cache(self, key, **kwargs):
# TODO: semantic lookup via vector similarity search is not implemented yet
pass
async def async_set_cache(self, key, value, **kwargs):
# TODO: async write path is not implemented yet
pass
async def async_get_cache(self, key, **kwargs):
# TODO: async semantic lookup is not implemented yet
pass
class S3Cache(BaseCache):
def __init__(
@@ -429,7 +523,7 @@ class DualCache(BaseCache):
class Cache:
def __init__(
self,
type: Optional[Literal["local", "redis", "s3"]] = "local",
type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
@@ -468,6 +562,8 @@ class Cache:
"""
if type == "redis":
self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
elif type == "redis-semantic":
self.cache = RedisSemanticCache(host, port, password, **kwargs)
elif type == "local":
self.cache = InMemoryCache()
elif type == "s3":
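For context, a minimal usage sketch of the new cache type. This is a sketch, not part of the commit: it assumes a reachable Redis Stack instance with the RedisJSON and RediSearch modules, and the import path and REDIS_* environment variable names mirror the test file below.

import os
import litellm
from litellm import completion
from litellm.caching import Cache

# Route LiteLLM's global cache through the new semantic backend.
litellm.cache = Cache(
    type="redis-semantic",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)

# Completions are now written through RedisSemanticCache.set_cache(), which
# embeds the response text and stores it under a "litellm-semantic:" key.
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "write a one sentence poem about spring"}],
    max_tokens=20,
)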

View file

@@ -987,3 +987,28 @@ def test_cache_context_managers():
# test_cache_context_managers()
def test_redis_semantic_cache_completion():
litellm.set_verbose = False
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="redis-semantic",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test2 for Redis Caching - non streaming")
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=20)
# response2 = completion(
# model="gpt-3.5-turbo", messages=messages,max_tokens=20
# )
# test_redis_semantic_cache_completion()
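Note that get_cache is still a stub in this commit, so cache reads are a no-op and the test only exercises the write path. A rough sketch of how the semantic lookup could eventually look with redis-py's KNN query syntax; the index name and the @vector field alias come from the commented-out schema above, while the prompt extraction from kwargs and the 0.1 distance threshold are pure assumptions:

import logging
import numpy as np
import litellm
from redis.commands.search.query import Query

def get_cache(self, key, **kwargs):
    try:
        # Embed the incoming prompt with the same model set_cache() uses,
        # again with "no-store" so the embedding call itself is not cached.
        messages = kwargs.get("messages", [])  # assumption: prompt arrives via kwargs
        embedding_response = litellm.embedding(
            model="text-embedding-ada-002",
            input=messages[-1]["content"],
            cache={"no-store": True},
        )
        query_vector = np.array(
            embedding_response["data"][0]["embedding"], dtype=np.float32
        ).tobytes()
        # KNN-1 search against the vector field defined in the schema sketch.
        query = (
            Query("*=>[KNN 1 @vector $vec AS distance]")
            .sort_by("distance")
            .return_fields("response", "distance")
            .dialect(2)
        )
        results = self.redis_client.ft("idx:litellm_completion_response_vss").search(
            query, query_params={"vec": query_vector}
        )
        # Only treat a sufficiently close neighbor as a hit (threshold is a guess).
        if results.docs and float(results.docs[0].distance) < 0.1:
            return results.docs[0].response
    except Exception as e:
        logging.debug(f"LiteLLM Caching: get() - Got exception from REDIS: {e}")
    return None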