forked from phoenix/litellm-mirror

(feat) add semantic cache

This commit is contained in:
parent 646764f1d4
commit d4a799a3ca

2 changed files with 124 additions and 3 deletions
@@ -83,7 +83,6 @@ class InMemoryCache(BaseCache):
         self.cache_dict.clear()
         self.ttl_dict.clear()

     async def disconnect(self):
         pass
@@ -217,7 +216,6 @@ class RedisCache(BaseCache):
     def flush_cache(self):
         self.redis_client.flushall()

     async def disconnect(self):
         pass
@@ -225,6 +223,102 @@ class RedisCache(BaseCache):
         self.redis_client.delete(key)


+class RedisSemanticCache(RedisCache):
+    def __init__(self, host, port, password, **kwargs):
+        super().__init__()
+
+        # from redis.commands.search.field import TagField, TextField, NumericField, VectorField
+        # from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+        # from redis.commands.search.query import Query
+
+        # INDEX_NAME = 'idx:litellm_completion_response_vss'
+        # DOC_PREFIX = 'bikes:'
+
+        # try:
+        #     # check to see if index exists
+        #     client.ft(INDEX_NAME).info()
+        #     print('Index already exists!')
+        # except:
+        #     # schema
+        #     schema = (
+        #         TextField('$.model', no_stem=True, as_name='model'),
+        #         TextField('$.brand', no_stem=True, as_name='brand'),
+        #         NumericField('$.price', as_name='price'),
+        #         TagField('$.type', as_name='type'),
+        #         TextField('$.description', as_name='description'),
+        #         VectorField('$.description_embeddings',
+        #             'FLAT', {
+        #                 'TYPE': 'FLOAT32',
+        #                 'DIM': VECTOR_DIMENSION,
+        #                 'DISTANCE_METRIC': 'COSINE',
+        #             }, as_name='vector'
+        #         ),
+        #     )
+
+        #     # index definition
+        #     definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.JSON)
+
+        #     # create index
+        #     client.ft(INDEX_NAME).create_index(fields=schema, definition=definition)
+
+    def set_cache(self, key, value, **kwargs):
+        ttl = kwargs.get("ttl", None)
+        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
+        try:
+            # get the text of the LLM response out of the cached value
+            llm_response = value["response"]
+
+            # if llm_response is a string, convert it to a dictionary
+            if isinstance(llm_response, str):
+                llm_response = json.loads(llm_response)
+
+            response = llm_response["choices"][0]["message"]["content"]
+
+            # create an embedding of the response text; cache={"no-store": True}
+            # keeps the embedding call itself out of the cache
+            embedding_response = litellm.embedding(
+                model="text-embedding-ada-002",
+                input=response,
+                cache={"no-store": True},
+            )
+
+            raw_embedding = embedding_response["data"][0]["embedding"]
+            raw_embedding_dimension = len(raw_embedding)
+
+            key = "litellm-semantic:" + key
+            # store the dict directly: redis_client.json().set() serializes it,
+            # whereas wrapping it in json.dumps() would store a JSON string that
+            # a vector index over $.embedding could not search
+            self.redis_client.json().set(
+                name=key,
+                path="$",
+                obj={
+                    "response": response,
+                    "embedding": raw_embedding,
+                    "dimension": raw_embedding_dimension,
+                },
+            )
+
+            stored_redis_value = self.redis_client.json().get(name=key)
+            # print("Stored Redis Value: ", stored_redis_value)
+
+        except Exception as e:
+            # NON blocking - notify users Redis is throwing an exception
+            logging.debug(f"LiteLLM Caching: set() - Got exception from REDIS: {e}")
+
+    def get_cache(self, key, **kwargs):
+        pass
+
+    async def async_set_cache(self, key, value, **kwargs):
+        pass
+
+    async def async_get_cache(self, key, **kwargs):
+        pass
+
+
 class S3Cache(BaseCache):
     def __init__(
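Note on the commented-out block in __init__ above: it still carries DOC_PREFIX = 'bikes:' and the bike-shop field names from the Redis vector-search docs example it was adapted from, so it does not yet match what set_cache writes. A minimal sketch of the same index setup adapted to the stored schema ($.response, $.embedding, $.dimension under the litellm-semantic: key prefix) might look like the following; the helper name and prefix choice are assumptions, and 1536 is the output dimension of text-embedding-ada-002:

    # sketch only: index setup adapted to what set_cache actually stores
    from redis.commands.search.field import TextField, NumericField, VectorField
    from redis.commands.search.indexDefinition import IndexDefinition, IndexType

    INDEX_NAME = "idx:litellm_completion_response_vss"  # from the commented-out constant
    DOC_PREFIX = "litellm-semantic:"  # assumption: match the key prefix set_cache uses
    VECTOR_DIMENSION = 1536  # text-embedding-ada-002 embedding size

    def ensure_semantic_index(redis_client):  # hypothetical helper, not in this commit
        try:
            redis_client.ft(INDEX_NAME).info()  # raises if the index does not exist
        except Exception:
            schema = (
                TextField("$.response", as_name="response"),
                NumericField("$.dimension", as_name="dimension"),
                VectorField(
                    "$.embedding",
                    "FLAT",
                    {
                        "TYPE": "FLOAT32",
                        "DIM": VECTOR_DIMENSION,
                        "DISTANCE_METRIC": "COSINE",
                    },
                    as_name="vector",
                ),
            )
            definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.JSON)
            redis_client.ft(INDEX_NAME).create_index(fields=schema, definition=definition)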
@@ -429,7 +523,7 @@ class DualCache(BaseCache):
 class Cache:
     def __init__(
         self,
-        type: Optional[Literal["local", "redis", "s3"]] = "local",
+        type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
         host: Optional[str] = None,
         port: Optional[str] = None,
         password: Optional[str] = None,
@@ -468,6 +562,8 @@ class Cache:
         """
         if type == "redis":
             self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
+        elif type == "redis-semantic":
+            self.cache = RedisSemanticCache(host, port, password, **kwargs)
         elif type == "local":
             self.cache = InMemoryCache()
         elif type == "s3":
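With the redis-semantic branch wired in above, opting in from application code presumably mirrors the existing redis type, only the type string changes. A minimal sketch (the import path and connection values are assumptions; the test below shows the same pattern with environment variables):

    import litellm
    from litellm.caching import Cache  # assumption: same module as the other cache types

    litellm.cache = Cache(
        type="redis-semantic",
        host="localhost",
        port="6379",
        password="...",
    )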
@@ -987,3 +987,28 @@ def test_cache_context_managers():


 # test_cache_context_managers()
+
+
+def test_redis_semantic_cache_completion():
+    litellm.set_verbose = False
+
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
+    ]
+    litellm.cache = Cache(
+        type="redis-semantic",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    print("test for Redis semantic caching - non streaming")
+    response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=20)
+    # response2 = completion(
+    #     model="gpt-3.5-turbo", messages=messages, max_tokens=20
+    # )
+
+
+# test_redis_cache_completion()
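get_cache is still a stub in this commit, which is presumably why response2 above remains commented out: nothing is read back semantically yet. Assuming an index like the one sketched earlier, the eventual lookup would likely be a KNN query over the stored embeddings. A sketch only; the helper name and the 0.1 cosine-distance threshold are illustrative assumptions:

    import numpy as np
    import litellm
    from redis.commands.search.query import Query

    def semantic_lookup(redis_client, prompt, threshold=0.1):  # hypothetical helper
        # embed the incoming prompt the same way set_cache embeds responses
        embedding_response = litellm.embedding(
            model="text-embedding-ada-002",
            input=prompt,
            cache={"no-store": True},
        )
        query_vector = np.array(
            embedding_response["data"][0]["embedding"], dtype=np.float32
        ).tobytes()

        # nearest stored response by cosine distance over the JSON vector field
        query = (
            Query("*=>[KNN 1 @vector $vec AS score]")
            .sort_by("score")
            .return_fields("response", "score")
            .dialect(2)
        )
        results = redis_client.ft("idx:litellm_completion_response_vss").search(
            query, query_params={"vec": query_vector}
        )
        if results.docs and float(results.docs[0].score) <= threshold:
            return results.docs[0].response  # semantic hit
        return None  # miss: fall through to the real completion call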