(feat) add semantic cache

parent 646764f1d4
commit d4a799a3ca

2 changed files with 124 additions and 3 deletions
```diff
@@ -83,7 +83,6 @@ class InMemoryCache(BaseCache):
         self.cache_dict.clear()
         self.ttl_dict.clear()
 
-
     async def disconnect(self):
         pass
 
```
```diff
@@ -217,7 +216,6 @@ class RedisCache(BaseCache):
     def flush_cache(self):
         self.redis_client.flushall()
 
-
    async def disconnect(self):
        pass
 
```
```diff
@@ -225,6 +223,102 @@ class RedisCache(BaseCache):
         self.redis_client.delete(key)
 
 
+class RedisSemanticCache(RedisCache):
+    def __init__(self, host, port, password, **kwargs):
+        super().__init__()
+
+        # from redis.commands.search.field import TagField, TextField, NumericField, VectorField
+        # from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+        # from redis.commands.search.query import Query
+
+        # INDEX_NAME = 'idx:litellm_completion_response_vss'
+        # DOC_PREFIX = 'bikes:'
+
+        # try:
+        #     # check to see if index exists
+        #     client.ft(INDEX_NAME).info()
+        #     print('Index already exists!')
+        # except:
+        #     # schema
+        #     schema = (
+        #         TextField('$.model', no_stem=True, as_name='model'),
+        #         TextField('$.brand', no_stem=True, as_name='brand'),
+        #         NumericField('$.price', as_name='price'),
+        #         TagField('$.type', as_name='type'),
+        #         TextField('$.description', as_name='description'),
+        #         VectorField('$.description_embeddings',
+        #             'FLAT', {
+        #                 'TYPE': 'FLOAT32',
+        #                 'DIM': VECTOR_DIMENSION,
+        #                 'DISTANCE_METRIC': 'COSINE',
+        #             }, as_name='vector'
+        #         ),
+        #     )
+
+        #     # index Definition
+        #     definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.JSON)
+
+        #     # create Index
+        #     client.ft(INDEX_NAME).create_index(fields=schema, definition=definition)
+
+    def set_cache(self, key, value, **kwargs):
+        ttl = kwargs.get("ttl", None)
+        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
+        try:
+            # get text response
+            # print("in redis semantic cache: value: ", value)
+            llm_response = value["response"]
+
+            # if llm_response is a string, convert it to a dictionary
+            if isinstance(llm_response, str):
+                llm_response = json.loads(llm_response)
+
+            # print("converted llm_response: ", llm_response)
+            response = llm_response["choices"][0]["message"]["content"]
+
+            # create embedding response
+            embedding_response = litellm.embedding(
+                model="text-embedding-ada-002",
+                input=response,
+                cache={"no-store": True},
+            )
+
+            raw_embedding = embedding_response["data"][0]["embedding"]
+            raw_embedding_dimension = len(raw_embedding)
+
+            # print("embedding: ", raw_embedding)
+            key = "litellm-semantic:" + key
+            self.redis_client.json().set(
+                name=key,
+                path="$",
+                obj=json.dumps(
+                    {
+                        "response": response,
+                        "embedding": raw_embedding,
+                        "dimension": raw_embedding_dimension,
+                    }
+                ),
+            )
+
+            stored_redis_value = self.redis_client.json().get(name=key)
+
+            # print("Stored Redis Value: ", stored_redis_value)
+
+        except Exception as e:
+            # print("Error occurred: ", e)
+            # NON blocking - notify users Redis is throwing an exception
+            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+
+    def get_cache(self, key, **kwargs):
+        pass
+
+    async def async_set_cache(self, key, value, **kwargs):
+        pass
+
+    async def async_get_cache(self, key, **kwargs):
+        pass
+
+
 class S3Cache(BaseCache):
     def __init__(
```
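The index bootstrap in `__init__` is left commented out and still carries field names from the redis-py vector-search tutorial (`bikes:`, `$.description_embeddings`), and `get_cache` is a stub. Below is a minimal sketch of what both pieces could look like once adapted to the JSON documents `set_cache` writes; the index name, the reuse of the `litellm-semantic:` key prefix, the 1536 dimension (what text-embedding-ada-002 returns), and the `semantic_lookup` helper are all assumptions, not part of this commit. One caveat: `set_cache` passes `json.dumps(...)` to `json().set()`, which stores a JSON string rather than an object, so for `$.embedding` to be indexable the dict would likely need to be passed directly.

```python
import numpy as np
import redis
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from redis.exceptions import ResponseError

INDEX_NAME = "idx:litellm_completion_response_vss"  # assumed index name
DOC_PREFIX = "litellm-semantic:"  # matches the key prefix set_cache uses
EMBEDDING_DIM = 1536  # text-embedding-ada-002 output dimension

client = redis.Redis(host="localhost", port=6379)  # placeholder connection

# create the vector index once, mirroring the commented-out bootstrap above
try:
    client.ft(INDEX_NAME).info()  # raises ResponseError if the index is missing
except ResponseError:
    schema = (
        TextField("$.response", as_name="response"),
        VectorField(
            "$.embedding",
            "FLAT",
            {"TYPE": "FLOAT32", "DIM": EMBEDDING_DIM, "DISTANCE_METRIC": "COSINE"},
            as_name="vector",
        ),
    )
    definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.JSON)
    client.ft(INDEX_NAME).create_index(fields=schema, definition=definition)


def semantic_lookup(query_embedding):
    """One way get_cache could work: return the closest cached response."""
    q = (
        Query("(*)=>[KNN 1 @vector $vec AS score]")
        .sort_by("score")
        .return_fields("response", "score")
        .dialect(2)
    )
    vec = np.array(query_embedding, dtype=np.float32).tobytes()
    res = client.ft(INDEX_NAME).search(q, query_params={"vec": vec})
    return res.docs[0].response if res.docs else None
```

FLAT here means exact nearest-neighbour search; an HNSW index would trade exactness for speed once the cache grows large.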
```diff
@@ -429,7 +523,7 @@ class DualCache(BaseCache):
 class Cache:
     def __init__(
         self,
-        type: Optional[Literal["local", "redis", "s3"]] = "local",
+        type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
         host: Optional[str] = None,
         port: Optional[str] = None,
         password: Optional[str] = None,
```

```diff
@@ -468,6 +562,8 @@ class Cache:
         """
         if type == "redis":
             self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
+        elif type == "redis-semantic":
+            self.cache = RedisSemanticCache(host, port, password, **kwargs)
         elif type == "local":
             self.cache = InMemoryCache()
         elif type == "s3":
```
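With these two changes, the semantic backend becomes opt-in through the existing `Cache` constructor. As a hypothetical illustration of the behavior this is building toward (the `get_cache` vector lookup is still a stub in this commit), two differently worded but semantically equivalent prompts could resolve to the same cached completion; the connection values below are placeholders:

```python
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis-semantic",
    host="localhost",  # placeholder connection details
    port="6379",
    password="",
)

# two differently worded but semantically equivalent prompts
r1 = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
r2 = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Tell me France's capital city."}],
)
# with a working semantic get_cache, r2 could be served from r1's cached
# response even though the prompts (and their exact cache keys) differ
```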
```diff
@@ -987,3 +987,28 @@ def test_cache_context_managers():
 
 
 # test_cache_context_managers()
+
+
+def test_redis_semantic_cache_completion():
+    litellm.set_verbose = False
+
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
+    ]
+    litellm.cache = Cache(
+        type="redis-semantic",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    print("test2 for Redis Caching - non streaming")
+    response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=20)
+    # response2 = completion(
+    #     model="gpt-3.5-turbo", messages=messages, max_tokens=20
+    # )
+
+
+# test_redis_cache_completion()
```
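The second `completion` call in the new test is left commented out, so the test currently only exercises the write path. A hedged sketch of how the read-path check could be finished once `get_cache` is implemented; the equality assertion follows the pattern of litellm's other cache tests and is an assumption here:

```python
    # (appended inside test_redis_semantic_cache_completion, reusing its locals)
    response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=20)
    # a warm semantic cache should return the identical stored completion
    assert (
        response1["choices"][0]["message"]["content"]
        == response2["choices"][0]["message"]["content"]
    )
```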