(feat) working - sync semantic caching

Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 19:54:13 +00:00)
commit 80865f93b8, parent 168a2f7806
1 changed file with 152 additions and 75 deletions
```diff
@@ -223,94 +223,161 @@ class RedisCache(BaseCache):
         self.redis_client.delete(key)


-class RedisSemanticCache(RedisCache):
-    def __init__(self, host, port, password, **kwargs):
-        super().__init__()
-        # from redis.commands.search.field import TagField, TextField, NumericField, VectorField
-        # from redis.commands.search.indexDefinition import IndexDefinition, IndexType
-        # from redis.commands.search.query import Query
+class RedisSemanticCache(BaseCache):
+    def __init__(
+        self,
+        host=None,
+        port=None,
+        password=None,
+        redis_url=None,
+        similarity_threshold=None,
+        **kwargs,
+    ):
+        from redisvl.index import SearchIndex
+        from redisvl.query import VectorQuery
+
+        print_verbose(
+            "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
+        )
+        if similarity_threshold is None:
+            raise Exception("similarity_threshold must be provided, passed None")
+        self.similarity_threshold = similarity_threshold
+        schema = {
+            "index": {
+                "name": "litellm_semantic_cache_index",
+                "prefix": "litellm",
+                "storage_type": "hash",
+            },
+            "fields": {
+                "text": [{"name": "response"}, {"name": "prompt"}],
+                "vector": [
+                    {
+                        "name": "litellm_embedding",
+                        "dims": 1536,
+                        "distance_metric": "cosine",
+                        "algorithm": "flat",
+                        "datatype": "float32",
+                    }
+                ],
+            },
+        }
+        self.index = SearchIndex.from_dict(schema)
+        if redis_url is None:
+            # if no url was passed, require host, port, and password
+            if host is None or port is None or password is None:
+                raise Exception("Redis host, port, and password must be provided")
+            redis_url = "redis://:" + password + "@" + host + ":" + port
+        print_verbose(f"redis semantic-cache redis_url: {redis_url}")
+        self.index.connect(redis_url=redis_url)
+        self.index.create(overwrite=False)  # don't overwrite an existing index
```
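The new constructor accepts either a complete `redis_url` or separate `host`/`port`/`password` values, which it concatenates into a URL itself. A minimal sketch of both call styles (connection values are placeholders):

```python
# Placeholder values; both forms connect to the same index.
cache = RedisSemanticCache(
    redis_url="redis://:mypassword@localhost:6379",
    similarity_threshold=0.8,
)

# Or let the constructor build the URL. Note that port must be a string,
# because it is concatenated directly into the URL.
cache = RedisSemanticCache(
    host="localhost",
    port="6379",
    password="mypassword",
    similarity_threshold=0.8,
)
```

The commit also drops the earlier commented-out redis-py RediSearch prototype in favor of a `_get_cache_logic` helper: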
```diff
-        # INDEX_NAME = 'idx:litellm_completion_response_vss'
-        # DOC_PREFIX = 'bikes:'
-
-        # try:
-        #     # check to see if index exists
-        #     client.ft(INDEX_NAME).info()
-        #     print('Index already exists!')
-        # except:
-        #     # schema
-        #     schema = (
-        #         TextField('$.model', no_stem=True, as_name='model'),
-        #         TextField('$.brand', no_stem=True, as_name='brand'),
-        #         NumericField('$.price', as_name='price'),
-        #         TagField('$.type', as_name='type'),
-        #         TextField('$.description', as_name='description'),
-        #         VectorField('$.description_embeddings',
-        #             'FLAT', {
-        #                 'TYPE': 'FLOAT32',
-        #                 'DIM': VECTOR_DIMENSION,
-        #                 'DISTANCE_METRIC': 'COSINE',
-        #             }, as_name='vector'
-        #         ),
-        #     )
-
-        #     # index definition
-        #     definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.JSON)
-
-        #     # create index
-        #     client.ft(INDEX_NAME).create_index(fields=schema, definition=definition)
+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+
+        # check if cached_response is bytes
+        if isinstance(cached_response, bytes):
+            cached_response = cached_response.decode("utf-8")
+
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # convert string to dictionary
+        except Exception:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
```
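The `json.loads` then `ast.literal_eval` fallback matters because `set_cache` (below) stores `str(value)`: Python's repr uses single quotes, which is not valid JSON. A small self-contained illustration:

```python
import ast
import json

# set_cache stores str(value); Python's repr output is not valid JSON.
stored = str({"response": "hello world"})  # "{'response': 'hello world'}"
try:
    parsed = json.loads(stored)            # raises json.JSONDecodeError
except json.JSONDecodeError:
    parsed = ast.literal_eval(stored)      # recovers the original dict
assert parsed == {"response": "hello world"}
```

`set_cache` itself is rewritten to embed the prompt (rather than the response) and to load the entry into the redisvl index instead of a RedisJSON document: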
```diff
     def set_cache(self, key, value, **kwargs):
-        ttl = kwargs.get("ttl", None)
-        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
-        try:
-            # get text response
-            llm_response = value["response"]
-
-            # if llm_response is a string, convert it to a dictionary
-            if isinstance(llm_response, str):
-                llm_response = json.loads(llm_response)
-
-            response = llm_response["choices"][0]["message"]["content"]
-
-            # create embedding response
-            embedding_response = litellm.embedding(
-                model="text-embedding-ada-002",
-                input=response,
-                cache={"no-store": True},
-            )
-
-            raw_embedding = embedding_response["data"][0]["embedding"]
-            raw_embedding_dimension = len(raw_embedding)
-
-            key = "litellm-semantic:" + key
-            self.redis_client.json().set(
-                name=key,
-                path="$",
-                obj=json.dumps(
-                    {
-                        "response": response,
-                        "embedding": raw_embedding,
-                        "dimension": raw_embedding_dimension,
-                    }
-                ),
-            )
-
-            stored_redis_value = self.redis_client.json().get(name=key)
-
-        except Exception as e:
-            # NON blocking - notify users Redis is throwing an exception
-            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+        import numpy as np
+
+        print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # create an embedding for the prompt
+        embedding_response = litellm.embedding(
+            model="text-embedding-ada-002",
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        # make the embedding a numpy array, convert to bytes
+        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+        value = str(value)
+        assert isinstance(value, str)
+
+        new_data = [
+            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+        ]
+
+        # write the new entry into the semantic cache index
+        keys = self.index.load(new_data)
```
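Hash storage keeps the vector as a raw float32 byte string, so a 1536-dim ada-002 embedding occupies 1536 * 4 = 6144 bytes. A quick roundtrip sketch of the conversion used above, with a toy 4-dim vector:

```python
import numpy as np

embedding = [0.1, 0.2, 0.3, 0.4]
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
assert len(embedding_bytes) == 4 * 4  # 4 floats * 4 bytes each

# Redis hands the same bytes back; frombuffer restores the vector.
recovered = np.frombuffer(embedding_bytes, dtype=np.float32)
assert np.allclose(recovered, embedding)
```

`get_cache` gets a matching implementation: embed the incoming prompt, run a nearest-neighbor query, and compare the result against the similarity threshold: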
```diff
     def get_cache(self, key, **kwargs):
+        print_verbose(f"redis semantic-cache get_cache, kwargs: {kwargs}")
+        from redisvl.query import VectorQuery
+
+        # get the messages and flatten them into a single prompt string
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # convert the prompt to an embedding
+        embedding_response = litellm.embedding(
+            model="text-embedding-ada-002",
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        # KNN query against the vector field, returning the single closest entry
+        query = VectorQuery(
+            vector=embedding,
+            vector_field_name="litellm_embedding",
+            return_fields=["response", "prompt", "vector_distance"],
+            num_results=1,
+        )
+
+        results = self.index.query(query)
+        if results is None or len(results) == 0:
+            # nothing cached yet - treat as a cache miss
+            return None
+
+        vector_distance = float(results[0]["vector_distance"])
+        similarity = 1 - vector_distance
+        cached_prompt = results[0]["prompt"]
+
+        # if similarity exceeds self.similarity_threshold, return the cached response
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+        if similarity > self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
         pass

     async def async_set_cache(self, key, value, **kwargs):
```
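Redis returns a cosine *distance*, so the code converts it to a similarity score with `similarity = 1 - vector_distance` before comparing. A worked example with made-up distances:

```python
similarity_threshold = 0.8  # e.g. the value passed to Cache(...)

vector_distance = 0.12      # close paraphrase of a cached prompt
print(1 - vector_distance > similarity_threshold)  # True  -> cache hit

vector_distance = 0.35      # unrelated prompt
print(1 - vector_distance > similarity_threshold)  # False -> cache miss
```

The remaining hunks wire the new backend into the top-level `Cache` class, starting with the constructor signature: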
```diff
@@ -527,6 +594,7 @@ class Cache:
         host: Optional[str] = None,
         port: Optional[str] = None,
         password: Optional[str] = None,
+        similarity_threshold: Optional[float] = None,
         supported_call_types: Optional[
             List[Literal["completion", "acompletion", "embedding", "aembedding"]]
         ] = ["completion", "acompletion", "embedding", "aembedding"],
```
```diff
@@ -547,10 +615,12 @@ class Cache:
         Initializes the cache based on the given type.

         Args:
-            type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
+            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
             host (str, optional): The host address for the Redis cache. Required if type is "redis".
             port (int, optional): The port number for the Redis cache. Required if type is "redis".
             password (str, optional): The password for the Redis cache. Required if type is "redis".
+            similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic".
             supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
             **kwargs: Additional keyword arguments for redis.Redis() cache
```
```diff
@@ -563,7 +633,13 @@ class Cache:
         if type == "redis":
             self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
         elif type == "redis-semantic":
-            self.cache = RedisSemanticCache(host, port, password, **kwargs)
+            self.cache = RedisSemanticCache(
+                host,
+                port,
+                password,
+                similarity_threshold=similarity_threshold,
+                **kwargs,
+            )
         elif type == "local":
             self.cache = InMemoryCache()
         elif type == "s3":
```
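With the dispatch above in place, a caller can opt into semantic caching through litellm's standard cache hook. A usage sketch with placeholder connection values (assumes a Redis instance with the RediSearch module, e.g. redis-stack, and the `redisvl` package installed):

```python
import litellm
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis-semantic",
    host="localhost",
    port="6379",               # passed as a string; it is concatenated into the URL
    password="mypassword",
    similarity_threshold=0.8,  # required for redis-semantic
)

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    caching=True,
)
```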
```diff
@@ -743,6 +819,7 @@ class Cache:
             The cached result if it exists, otherwise None.
         """
         try:  # never block execution
+            messages = kwargs.get("messages", [])
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
             else:
@@ -752,7 +829,7 @@ class Cache:
                 max_age = cache_control_args.get(
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
-            cached_result = self.cache.get_cache(cache_key)
+            cached_result = self.cache.get_cache(cache_key, messages=messages)
             return self._get_cache_logic(
                 cached_result=cached_result, max_age=max_age
             )
```
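Threading `messages` through to `get_cache` is what makes the lookup semantic: the exact hashed `cache_key` no longer has to match, only the prompt's embedding does. A sketch of the intended effect, assuming the cache set up above and that the two prompts embed within the similarity threshold of each other:

```python
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Tell me a joke about computers"}],
    caching=True,
)  # cache miss: stores the prompt, its embedding, and the response

cached = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Tell me a computer joke"}],
    caching=True,
)  # can be served from the semantic cache despite the different cache_key
```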