qdrant semantic caching added

parent c64b44aa0e
commit 851db5ecea

3 changed files with 449 additions and 5 deletions

@@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht

:::

-## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
+## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache, Qdrant Semantic

<Tabs>

@@ -144,7 +144,67 @@ assert response1.id == response2.id

</TabItem>

<TabItem value="qdrant-sem" label="qdrant-semantic cache">

Install the Qdrant client:

```shell
pip install qdrant-client
```

You can set up your own cloud Qdrant cluster by following: https://qdrant.tech/documentation/quickstart-cloud/

To set up a Qdrant cluster locally, follow: https://qdrant.tech/documentation/quickstart/
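
For a quick local test, the Qdrant quickstart amounts to running the official Docker image; a minimal sketch (the username/password values are placeholders for whatever auth sits in front of your cluster, matching the environment variables the example below reads):

```shell
# pull and run Qdrant locally; 6333 is Qdrant's default HTTP port
docker pull qdrant/qdrant
docker run -p 6333:6333 qdrant/qdrant

# the Python example below reads these environment variables
export QDRANT_URL="http://localhost:6333"   # or your cloud cluster URL
export QDRANT_USERNAME="your_username"      # placeholder
export QDRANT_PASSWORD="your_password"      # placeholder
```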

```python
import os
import random

import litellm
from litellm import completion
from litellm.caching import Cache

random_number = random.randint(
    1, 100000
)  # add a random number to ensure it's always adding / reading from cache

print("testing semantic caching")
litellm.cache = Cache(
    type="qdrant-semantic",
    qdrant_url=os.environ["QDRANT_URL"],
    qdrant_username=os.environ["QDRANT_USERNAME"],
    qdrant_password=os.environ["QDRANT_PASSWORD"],
    qdrant_collection_name="your_collection_name",  # any name for your collection
    similarity_threshold=0.7,  # similarity threshold for cache hits: 0 == no similarity, 1 == exact matches, 0.5 == 50% similarity
    qdrant_quantization_config="binary",  # one of the 'binary', 'product' or 'scalar' quantizations supported by qdrant
    qdrant_semantic_cache_embedding_model="text-embedding-ada-002",  # this model is passed to litellm.embedding(); any litellm.embedding() model is supported here
)

response1 = completion(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": f"write a one sentence poem about: {random_number}",
        }
    ],
    max_tokens=20,
)
print(f"response1: {response1}")

random_number = random.randint(1, 100000)

response2 = completion(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": f"write a one sentence poem about: {random_number}",
        }
    ],
    max_tokens=20,
)
print(f"response2: {response2}")

assert response1.id == response2.id
# response1 == response2, response 1 is cached
```
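
The two prompts differ only in the random number, so their embeddings score above the 0.7 `similarity_threshold`; `response2` is therefore served from the cache and shares `response1`'s `id`.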

</TabItem>

<TabItem value="in-mem" label="in memory cache">

@@ -435,6 +495,14 @@ def __init__(
    # disk cache params
    disk_cache_dir=None,

    # qdrant cache params
    qdrant_username: Optional[str] = None,
    qdrant_password: Optional[str] = None,
    qdrant_url: Optional[str] = None,
    qdrant_collection_name: Optional[str] = None,
    qdrant_quantization_config: Optional[str] = None,
    qdrant_semantic_cache_embedding_model="text-embedding-ada-002",

    **kwargs
):
```
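
These `Cache(...)` keyword arguments are forwarded directly to `QdrantSemanticCache`; the `elif type == "qdrant-semantic"` branch further down in this diff shows the mapping (`qdrant_collection_name` becomes `collection_name`, and so on).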

@@ -1217,6 +1217,354 @@ class RedisSemanticCache(BaseCache):
    async def _index_info(self):
        return await self.index.ainfo()


class QdrantSemanticCache(BaseCache):
    def __init__(
        self,
        qdrant_username=None,
        qdrant_password=None,
        qdrant_url=None,
        collection_name=None,
        similarity_threshold=None,
        quantization_config=None,
        embedding_model="text-embedding-ada-002",
    ):
        import base64

        from qdrant_client import models, AsyncQdrantClient, QdrantClient

        if collection_name is None:
            raise Exception("collection_name must be provided, passed None")

        self.collection_name = collection_name
        print_verbose(
            f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
        )

        if similarity_threshold is None:
            raise Exception("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model

        if qdrant_url is None or qdrant_username is None or qdrant_password is None:
            import os

            qdrant_url = os.getenv("QDRANT_URL")
            qdrant_username = os.getenv("QDRANT_USERNAME")
            qdrant_password = os.getenv("QDRANT_PASSWORD")
            if (
                qdrant_url is None
                or qdrant_username is None
                or qdrant_password is None
            ):
                raise Exception("Qdrant url, username and password must be provided")

        print_verbose(f"qdrant semantic-cache qdrant_url: {qdrant_url}")
        # Qdrant is reached with HTTP Basic auth built from the username/password pair
        self.credentials = f"{qdrant_username}:{qdrant_password}"
        self.encoded_credentials = base64.b64encode(self.credentials.encode()).decode()
        self.headers = {"Authorization": f"Basic {self.encoded_credentials}"}

        self.qdrant_client = QdrantClient(
            url=qdrant_url,
            timeout=1200,
            headers=self.headers,
        )
        self.qdrant_client_async = AsyncQdrantClient(
            url=qdrant_url,
            timeout=1200,
            headers=self.headers,
        )
        if quantization_config is None:
            print_verbose(
                "Quantization config is not provided. Default binary quantization will be used."
            )

        if self.qdrant_client.collection_exists(collection_name=self.collection_name):
            self.collection_info = self.qdrant_client.get_collection(
                self.collection_name
            )
            print_verbose(
                f"Collection already exists.\nCollection details:{self.collection_info}"
            )
        else:
            if quantization_config is None or quantization_config == "binary":
                quantization_params = models.BinaryQuantization(
                    binary=models.BinaryQuantizationConfig(always_ram=False),
                )
            elif quantization_config == "scalar":
                quantization_params = models.ScalarQuantization(
                    scalar=models.ScalarQuantizationConfig(
                        type=models.ScalarType.INT8,
                        quantile=0.99,
                        always_ram=False,
                    ),
                )
            elif quantization_config == "product":
                quantization_params = models.ProductQuantization(
                    product=models.ProductQuantizationConfig(
                        compression=models.CompressionRatio.X16,
                        always_ram=False,
                    ),
                )
            else:
                raise Exception(
                    "Quantization config must be one of 'scalar', 'binary' or 'product'"
                )

            self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=1536,  # dimension of text-embedding-ada-002 vectors
                    distance=models.Distance.COSINE,
                ),
                quantization_config=quantization_params,
            )

            self.collection_info = self.qdrant_client.get_collection(
                self.collection_name
            )
            print_verbose(
                f"New collection created.\nCollection details:{self.collection_info}"
            )

    def _get_cache_logic(self, cached_response: Any):
        if cached_response is None:
            return cached_response
        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except Exception:
            cached_response = ast.literal_eval(cached_response)
        return cached_response
    def set_cache(self, key, value, **kwargs):
        print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
        import uuid

        from qdrant_client import models

        # get the prompt
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # create an embedding for prompt
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        value = str(value)
        assert isinstance(value, str)

        self.qdrant_client.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    payload={
                        "text": prompt,
                        "response": value,
                    },
                    vector=embedding,
                ),
            ],
        )
        return
    def get_cache(self, key, **kwargs):
        print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
        from qdrant_client import models

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # convert to embedding
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        results = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            search_params=models.SearchParams(
                quantization=models.QuantizationSearchParams(
                    ignore=False,
                    rescore=True,
                    oversampling=3.0,
                ),
                exact=False,
            ),
            limit=1,
        )

        if results is None:
            return None
        if isinstance(results, list):
            if len(results) == 0:
                return None

        similarity = results[0].score
        cached_prompt = results[0].payload["text"]

        # check similarity; if above self.similarity_threshold, return the cached response
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        if similarity >= self.similarity_threshold:
            # cache hit!
            cached_value = results[0].payload["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss!
            return None
    async def async_set_cache(self, key, value, **kwargs):
        import uuid

        from qdrant_client import models

        from litellm.proxy.proxy_server import llm_model_list, llm_router

        print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")

        # get the prompt
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # create an embedding for prompt
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        value = str(value)
        assert isinstance(value, str)

        await self.qdrant_client_async.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    payload={
                        "text": prompt,
                        "response": value,
                    },
                    vector=embedding,
                ),
            ],
        )
        return
    async def async_get_cache(self, key, **kwargs):
        print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
        from qdrant_client import models

        from litellm.proxy.proxy_server import llm_model_list, llm_router

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        results = await self.qdrant_client_async.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            search_params=models.SearchParams(
                quantization=models.QuantizationSearchParams(
                    ignore=False,
                    rescore=True,
                    oversampling=3.0,
                ),
                exact=False,
            ),
            limit=1,
        )

        if results is None:
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None
        if isinstance(results, list):
            if len(results) == 0:
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None

        similarity = results[0].score
        cached_prompt = results[0].payload["text"]

        # check similarity; if above self.similarity_threshold, return the cached response
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )

        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

        if similarity >= self.similarity_threshold:
            # cache hit!
            cached_value = results[0].payload["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss!
            return None

    async def _collection_info(self):
        return self.collection_info
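
For orientation, a minimal sketch of exercising this class directly (assumes a running Qdrant cluster, the `QDRANT_*` environment variables from the docs above, and an `OPENAI_API_KEY` for the default embedding model; in practice the class is driven through `litellm.cache` rather than called by hand):

```python
from litellm.caching import QdrantSemanticCache

cache = QdrantSemanticCache(
    collection_name="litellm-semantic-cache",  # illustrative name
    similarity_threshold=0.7,
)  # credentials fall back to QDRANT_URL / QDRANT_USERNAME / QDRANT_PASSWORD

# store a response keyed by the prompt's embedding
cache.set_cache(
    key=None,  # unused; retrieval is by vector similarity, not key
    value={"id": "chatcmpl-123", "text": "hello!"},
    messages=[{"role": "user", "content": "say hi"}],
)

# a semantically similar prompt should return the stored response
hit = cache.get_cache(key=None, messages=[{"role": "user", "content": "please say hi"}])
print(hit)
```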


class S3Cache(BaseCache):
    def __init__(

@@ -1673,7 +2021,7 @@ class Cache:
    def __init__(
        self,
        type: Optional[
-            Literal["local", "redis", "redis-semantic", "s3", "disk"]
+            Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]
        ] = "local",
        host: Optional[str] = None,
        port: Optional[str] = None,

@@ -1722,17 +2070,27 @@ class Cache:
        redis_semantic_cache_embedding_model="text-embedding-ada-002",
        redis_flush_size=None,
        disk_cache_dir=None,
        qdrant_username: Optional[str] = None,
        qdrant_password: Optional[str] = None,
        qdrant_url: Optional[str] = None,
        qdrant_collection_name: Optional[str] = None,
        qdrant_quantization_config: Optional[str] = None,
        qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic"
|
||||
qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic"
|
||||
qdrant_username (str, optional): The username for the qdrant cluster. Required if type is "qdrant-semantic"
|
||||
qdrant_password (str, optional): The password for the qdrant cluster. Required if type is "qdrant-semantic"
|
||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic"
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic"
|
||||
|
||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
@ -1757,6 +2115,16 @@ class Cache:
|
|||
embedding_model=redis_semantic_cache_embedding_model,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == "qdrant-semantic":
|
||||
self.cache = QdrantSemanticCache(
|
||||
qdrant_username= qdrant_username,
|
||||
qdrant_password= qdrant_password,
|
||||
qdrant_url= qdrant_url,
|
||||
collection_name= qdrant_collection_name,
|
||||
similarity_threshold= similarity_threshold,
|
||||
quantization_config= qdrant_quantization_config,
|
||||
embedding_model= qdrant_semantic_cache_embedding_model,
|
||||
)
|
||||
elif type == "local":
|
||||
self.cache = InMemoryCache()
|
||||
elif type == "s3":
|
||||
|
|
|
@@ -113,7 +113,7 @@ import importlib.metadata
from openai import OpenAIError as OriginalError

from ._logging import verbose_logger
-from .caching import RedisCache, RedisSemanticCache, S3Cache
+from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
from .exceptions import (
    APIConnectionError,
    APIError,

@@ -1114,6 +1114,14 @@ def client(original_function):
                    cached_result = await litellm.cache.async_get_cache(
                        *args, **kwargs
                    )
                elif isinstance(litellm.cache.cache, QdrantSemanticCache):
                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                    kwargs["preset_cache_key"] = (
                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
                    )
                    cached_result = await litellm.cache.async_get_cache(
                        *args, **kwargs
                    )
                else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                    kwargs["preset_cache_key"] = (