mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
Merge pull request #5018 from haadirakhangi/main
Qdrant Semantic Caching
This commit is contained in:
commit
a34aeafdb5
5 changed files with 694 additions and 6 deletions
|
@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
|
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache, Qdrant Semantic
|
||||||
|
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
@ -144,7 +144,62 @@ assert response1.id == response2.id
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="qdrant-sem" label="qdrant-semantic cache">
|
||||||
|
|
||||||
|
You can set up your own cloud Qdrant cluster by following this: https://qdrant.tech/documentation/quickstart-cloud/
|
||||||
|
|
||||||
|
To set up a Qdrant cluster locally follow: https://qdrant.tech/documentation/quickstart/
|
||||||
|
```python
|
||||||
|
import litellm
|
||||||
|
from litellm import completion
|
||||||
|
from litellm.caching import Cache
|
||||||
|
|
||||||
|
random_number = random.randint(
|
||||||
|
1, 100000
|
||||||
|
) # add a random number to ensure it's always adding / reading from cache
|
||||||
|
|
||||||
|
print("testing semantic caching")
|
||||||
|
litellm.cache = Cache(
|
||||||
|
type="qdrant-semantic",
|
||||||
|
qdrant_host_type="cloud", # can be either 'cloud' or 'local'
|
||||||
|
qdrant_url=os.environ["QDRANT_URL"],
|
||||||
|
qdrant_api_key=os.environ["QDRANT_API_KEY"],
|
||||||
|
qdrant_collection_name="your_collection_name", # any name of your collection
|
||||||
|
similarity_threshold=0.7, # similarity threshold for cache hits, 0 == no similarity, 1 = exact matches, 0.5 == 50% similarity
|
||||||
|
qdrant_quantization_config ="binary", # can be one of 'binary', 'product' or 'scalar' quantizations that is supported by qdrant
|
||||||
|
qdrant_semantic_cache_embedding_model="text-embedding-ada-002", # this model is passed to litellm.embedding(), any litellm.embedding() model is supported here
|
||||||
|
)
|
||||||
|
|
||||||
|
response1 = completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"write a one sentence poem about: {random_number}",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens=20,
|
||||||
|
)
|
||||||
|
print(f"response1: {response1}")
|
||||||
|
|
||||||
|
random_number = random.randint(1, 100000)
|
||||||
|
|
||||||
|
response2 = completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"write a one sentence poem about: {random_number}",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens=20,
|
||||||
|
)
|
||||||
|
print(f"response2: {response1}")
|
||||||
|
assert response1.id == response2.id
|
||||||
|
# response1 == response2, response 1 is cached
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
<TabItem value="in-mem" label="in memory cache">
|
<TabItem value="in-mem" label="in memory cache">
|
||||||
|
|
||||||
|
@ -435,6 +490,14 @@ def __init__(
|
||||||
# disk cache params
|
# disk cache params
|
||||||
disk_cache_dir=None,
|
disk_cache_dir=None,
|
||||||
|
|
||||||
|
# qdrant cache params
|
||||||
|
qdrant_url: Optional[str] = None,
|
||||||
|
qdrant_api_key: Optional[str] = None,
|
||||||
|
qdrant_collection_name: Optional[str] = None,
|
||||||
|
qdrant_quantization_config: Optional[str] = None,
|
||||||
|
qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||||
|
qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
|
||||||
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
```
|
```
|
||||||
|
|
|
@ -1219,6 +1219,410 @@ class RedisSemanticCache(BaseCache):
|
||||||
async def _index_info(self):
|
async def _index_info(self):
|
||||||
return await self.index.ainfo()
|
return await self.index.ainfo()
|
||||||
|
|
||||||
|
class QdrantSemanticCache(BaseCache):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qdrant_url=None,
|
||||||
|
qdrant_api_key = None,
|
||||||
|
collection_name=None,
|
||||||
|
similarity_threshold=None,
|
||||||
|
quantization_config=None,
|
||||||
|
embedding_model="text-embedding-ada-002",
|
||||||
|
host_type = None
|
||||||
|
):
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_httpx_client,
|
||||||
|
_get_async_httpx_client
|
||||||
|
)
|
||||||
|
|
||||||
|
if collection_name is None:
|
||||||
|
raise Exception("collection_name must be provided, passed None")
|
||||||
|
|
||||||
|
self.collection_name = collection_name
|
||||||
|
print_verbose(
|
||||||
|
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if similarity_threshold is None:
|
||||||
|
raise Exception("similarity_threshold must be provided, passed None")
|
||||||
|
self.similarity_threshold = similarity_threshold
|
||||||
|
self.embedding_model = embedding_model
|
||||||
|
|
||||||
|
if host_type=="cloud":
|
||||||
|
import os
|
||||||
|
if qdrant_url is None:
|
||||||
|
qdrant_url = os.getenv('QDRANT_URL')
|
||||||
|
if qdrant_api_key is None:
|
||||||
|
qdrant_api_key = os.getenv('QDRANT_API_KEY')
|
||||||
|
if qdrant_url is not None and qdrant_api_key is not None:
|
||||||
|
headers = {
|
||||||
|
"api-key": qdrant_api_key,
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise Exception("Qdrant url and api_key must be provided for qdrant cloud hosting")
|
||||||
|
elif host_type=="local":
|
||||||
|
import os
|
||||||
|
if qdrant_url is None:
|
||||||
|
qdrant_url = os.getenv('QDRANT_URL')
|
||||||
|
if qdrant_url is None:
|
||||||
|
raise Exception("Qdrant url must be provided for qdrant local hosting")
|
||||||
|
if qdrant_api_key is None:
|
||||||
|
qdrant_api_key = os.getenv('QDRANT_API_KEY')
|
||||||
|
if qdrant_api_key is None:
|
||||||
|
print_verbose('Running locally without API Key.')
|
||||||
|
headers= {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print_verbose("Running locally with API Key")
|
||||||
|
headers = {
|
||||||
|
"api-key": qdrant_api_key,
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise Exception("Host type can be either 'local' or 'cloud'")
|
||||||
|
|
||||||
|
self.qdrant_url = qdrant_url
|
||||||
|
self.qdrant_api_key = qdrant_api_key
|
||||||
|
print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}")
|
||||||
|
|
||||||
|
self.headers = headers
|
||||||
|
|
||||||
|
self.sync_client = _get_httpx_client()
|
||||||
|
self.async_client = _get_async_httpx_client()
|
||||||
|
|
||||||
|
if quantization_config is None:
|
||||||
|
print('Quantization config is not provided. Default binary quantization will be used.')
|
||||||
|
|
||||||
|
collection_exists = self.sync_client.get(
|
||||||
|
url= f"{self.qdrant_url}/collections/{self.collection_name}/exists",
|
||||||
|
headers=self.headers
|
||||||
|
)
|
||||||
|
if collection_exists.json()['result']['exists']:
|
||||||
|
collection_details = self.sync_client.get(
|
||||||
|
url=f"{self.qdrant_url}/collections/{self.collection_name}",
|
||||||
|
headers=self.headers
|
||||||
|
)
|
||||||
|
self.collection_info = collection_details.json()
|
||||||
|
print_verbose(f'Collection already exists.\nCollection details:{self.collection_info}')
|
||||||
|
else:
|
||||||
|
if quantization_config is None or quantization_config == 'binary':
|
||||||
|
quantization_params = {
|
||||||
|
"binary": {
|
||||||
|
"always_ram": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
elif quantization_config == 'scalar':
|
||||||
|
quantization_params = {
|
||||||
|
"scalar": {
|
||||||
|
"type": "int8",
|
||||||
|
"quantile": 0.99,
|
||||||
|
"always_ram": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
elif quantization_config == 'product':
|
||||||
|
quantization_params = {
|
||||||
|
"product": {
|
||||||
|
"compression": "x16",
|
||||||
|
"always_ram": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")
|
||||||
|
|
||||||
|
new_collection_status = self.sync_client.put(
|
||||||
|
url=f"{self.qdrant_url}/collections/{self.collection_name}",
|
||||||
|
json={
|
||||||
|
"vectors": {
|
||||||
|
"size": 1536,
|
||||||
|
"distance": "Cosine"
|
||||||
|
},
|
||||||
|
"quantization_config": quantization_params
|
||||||
|
},
|
||||||
|
headers=self.headers
|
||||||
|
)
|
||||||
|
if new_collection_status.json()["result"]:
|
||||||
|
collection_details = self.sync_client.get(
|
||||||
|
url=f"{self.qdrant_url}/collections/{self.collection_name}",
|
||||||
|
headers=self.headers
|
||||||
|
)
|
||||||
|
self.collection_info = collection_details.json()
|
||||||
|
print_verbose(f'New collection created.\nCollection details:{self.collection_info}')
|
||||||
|
else:
|
||||||
|
raise Exception("Error while creating new collection")
|
||||||
|
|
||||||
|
def _get_cache_logic(self, cached_response: Any):
|
||||||
|
if cached_response is None:
|
||||||
|
return cached_response
|
||||||
|
try:
|
||||||
|
cached_response = json.loads(
|
||||||
|
cached_response
|
||||||
|
) # Convert string to dictionary
|
||||||
|
except:
|
||||||
|
cached_response = ast.literal_eval(cached_response)
|
||||||
|
return cached_response
|
||||||
|
|
||||||
|
def set_cache(self, key, value, **kwargs):
|
||||||
|
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# get the prompt
|
||||||
|
messages = kwargs["messages"]
|
||||||
|
prompt = ""
|
||||||
|
for message in messages:
|
||||||
|
prompt += message["content"]
|
||||||
|
|
||||||
|
# create an embedding for prompt
|
||||||
|
embedding_response = litellm.embedding(
|
||||||
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the embedding
|
||||||
|
embedding = embedding_response["data"][0]["embedding"]
|
||||||
|
|
||||||
|
value = str(value)
|
||||||
|
assert isinstance(value, str)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"points": [
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"vector": embedding,
|
||||||
|
"payload": {
|
||||||
|
"text": prompt,
|
||||||
|
"response": value,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
keys = self.sync_client.put(
|
||||||
|
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
|
||||||
|
headers=self.headers,
|
||||||
|
json=data
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_cache(self, key, **kwargs):
|
||||||
|
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||||
|
|
||||||
|
# get the messages
|
||||||
|
messages = kwargs["messages"]
|
||||||
|
prompt = ""
|
||||||
|
for message in messages:
|
||||||
|
prompt += message["content"]
|
||||||
|
|
||||||
|
# convert to embedding
|
||||||
|
embedding_response = litellm.embedding(
|
||||||
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the embedding
|
||||||
|
embedding = embedding_response["data"][0]["embedding"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"vector": embedding,
|
||||||
|
"params": {
|
||||||
|
"quantization": {
|
||||||
|
"ignore": False,
|
||||||
|
"rescore": True,
|
||||||
|
"oversampling": 3.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"limit":1,
|
||||||
|
"with_payload": True
|
||||||
|
}
|
||||||
|
|
||||||
|
search_response = self.sync_client.post(
|
||||||
|
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
|
||||||
|
headers=self.headers,
|
||||||
|
json=data
|
||||||
|
)
|
||||||
|
results = search_response.json()["result"]
|
||||||
|
|
||||||
|
if results == None:
|
||||||
|
return None
|
||||||
|
if isinstance(results, list):
|
||||||
|
if len(results) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
similarity = results[0]["score"]
|
||||||
|
cached_prompt = results[0]["payload"]["text"]
|
||||||
|
|
||||||
|
# check similarity, if more than self.similarity_threshold, return results
|
||||||
|
print_verbose(
|
||||||
|
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
||||||
|
)
|
||||||
|
if similarity >= self.similarity_threshold:
|
||||||
|
# cache hit !
|
||||||
|
cached_value = results[0]["payload"]["response"]
|
||||||
|
print_verbose(
|
||||||
|
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||||
|
)
|
||||||
|
return self._get_cache_logic(cached_response=cached_value)
|
||||||
|
else:
|
||||||
|
# cache miss !
|
||||||
|
return None
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_set_cache(self, key, value, **kwargs):
|
||||||
|
from litellm.proxy.proxy_server import llm_router, llm_model_list
|
||||||
|
import uuid
|
||||||
|
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||||
|
|
||||||
|
# get the prompt
|
||||||
|
messages = kwargs["messages"]
|
||||||
|
prompt = ""
|
||||||
|
for message in messages:
|
||||||
|
prompt += message["content"]
|
||||||
|
# create an embedding for prompt
|
||||||
|
router_model_names = (
|
||||||
|
[m["model_name"] for m in llm_model_list]
|
||||||
|
if llm_model_list is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
if llm_router is not None and self.embedding_model in router_model_names:
|
||||||
|
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||||
|
embedding_response = await llm_router.aembedding(
|
||||||
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
metadata={
|
||||||
|
"user_api_key": user_api_key,
|
||||||
|
"semantic-cache-embedding": True,
|
||||||
|
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# convert to embedding
|
||||||
|
embedding_response = await litellm.aembedding(
|
||||||
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the embedding
|
||||||
|
embedding = embedding_response["data"][0]["embedding"]
|
||||||
|
|
||||||
|
value = str(value)
|
||||||
|
assert isinstance(value, str)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"points": [
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"vector": embedding,
|
||||||
|
"payload": {
|
||||||
|
"text": prompt,
|
||||||
|
"response": value,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
keys = await self.async_client.put(
|
||||||
|
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
|
||||||
|
headers=self.headers,
|
||||||
|
json=data
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def async_get_cache(self, key, **kwargs):
|
||||||
|
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||||
|
from litellm.proxy.proxy_server import llm_router, llm_model_list
|
||||||
|
|
||||||
|
# get the messages
|
||||||
|
messages = kwargs["messages"]
|
||||||
|
prompt = ""
|
||||||
|
for message in messages:
|
||||||
|
prompt += message["content"]
|
||||||
|
|
||||||
|
router_model_names = (
|
||||||
|
[m["model_name"] for m in llm_model_list]
|
||||||
|
if llm_model_list is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
if llm_router is not None and self.embedding_model in router_model_names:
|
||||||
|
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||||
|
embedding_response = await llm_router.aembedding(
|
||||||
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
metadata={
|
||||||
|
"user_api_key": user_api_key,
|
||||||
|
"semantic-cache-embedding": True,
|
||||||
|
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# convert to embedding
|
||||||
|
embedding_response = await litellm.aembedding(
|
||||||
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the embedding
|
||||||
|
embedding = embedding_response["data"][0]["embedding"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"vector": embedding,
|
||||||
|
"params": {
|
||||||
|
"quantization": {
|
||||||
|
"ignore": False,
|
||||||
|
"rescore": True,
|
||||||
|
"oversampling": 3.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"limit":1,
|
||||||
|
"with_payload": True
|
||||||
|
}
|
||||||
|
|
||||||
|
search_response = await self.async_client.post(
|
||||||
|
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
|
||||||
|
headers=self.headers,
|
||||||
|
json=data
|
||||||
|
)
|
||||||
|
|
||||||
|
results = search_response.json()["result"]
|
||||||
|
|
||||||
|
if results == None:
|
||||||
|
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||||
|
return None
|
||||||
|
if isinstance(results, list):
|
||||||
|
if len(results) == 0:
|
||||||
|
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||||
|
return None
|
||||||
|
|
||||||
|
similarity = results[0]["score"]
|
||||||
|
cached_prompt = results[0]["payload"]["text"]
|
||||||
|
|
||||||
|
# check similarity, if more than self.similarity_threshold, return results
|
||||||
|
print_verbose(
|
||||||
|
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
||||||
|
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
||||||
|
|
||||||
|
if similarity >= self.similarity_threshold:
|
||||||
|
# cache hit !
|
||||||
|
cached_value = results[0]["payload"]["response"]
|
||||||
|
print_verbose(
|
||||||
|
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||||
|
)
|
||||||
|
return self._get_cache_logic(cached_response=cached_value)
|
||||||
|
else:
|
||||||
|
# cache miss !
|
||||||
|
return None
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _collection_info(self):
|
||||||
|
return self.collection_info
|
||||||
|
|
||||||
class S3Cache(BaseCache):
|
class S3Cache(BaseCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -1676,7 +2080,7 @@ class Cache:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
type: Optional[
|
type: Optional[
|
||||||
Literal["local", "redis", "redis-semantic", "s3", "disk"]
|
Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]
|
||||||
] = "local",
|
] = "local",
|
||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
port: Optional[str] = None,
|
port: Optional[str] = None,
|
||||||
|
@ -1725,17 +2129,27 @@ class Cache:
|
||||||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||||
redis_flush_size=None,
|
redis_flush_size=None,
|
||||||
disk_cache_dir=None,
|
disk_cache_dir=None,
|
||||||
|
qdrant_url: Optional[str] = None,
|
||||||
|
qdrant_api_key: Optional[str] = None,
|
||||||
|
qdrant_collection_name: Optional[str] = None,
|
||||||
|
qdrant_quantization_config: Optional[str] = None,
|
||||||
|
qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||||
|
qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the cache based on the given type.
|
Initializes the cache based on the given type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
|
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic"
|
qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||||
|
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. Required if qdrant_host_type is "cloud" and optional if qdrant_host_type is "local".
|
||||||
|
qdrant_host_type (str, optional): Can be either "local" or "cloud". Should be "local" when you are running a local qdrant cluster or "cloud" when you are using a qdrant cloud cluster.
|
||||||
|
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||||
|
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||||
|
|
||||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||||
|
@ -1760,6 +2174,16 @@ class Cache:
|
||||||
embedding_model=redis_semantic_cache_embedding_model,
|
embedding_model=redis_semantic_cache_embedding_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
elif type == "qdrant-semantic":
|
||||||
|
self.cache = QdrantSemanticCache(
|
||||||
|
qdrant_url= qdrant_url,
|
||||||
|
qdrant_api_key= qdrant_api_key,
|
||||||
|
collection_name= qdrant_collection_name,
|
||||||
|
similarity_threshold= similarity_threshold,
|
||||||
|
quantization_config= qdrant_quantization_config,
|
||||||
|
embedding_model= qdrant_semantic_cache_embedding_model,
|
||||||
|
host_type=qdrant_host_type
|
||||||
|
)
|
||||||
elif type == "local":
|
elif type == "local":
|
||||||
self.cache = InMemoryCache()
|
self.cache = InMemoryCache()
|
||||||
elif type == "s3":
|
elif type == "s3":
|
||||||
|
|
|
@ -129,6 +129,62 @@ class AsyncHTTPHandler:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
async def put(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
data: Optional[Union[dict, str]] = None, # type: ignore
|
||||||
|
json: Optional[dict] = None,
|
||||||
|
params: Optional[dict] = None,
|
||||||
|
headers: Optional[dict] = None,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
if timeout is None:
|
||||||
|
timeout = self.timeout
|
||||||
|
req = self.client.build_request(
|
||||||
|
"PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
|
||||||
|
)
|
||||||
|
response = await self.client.send(req, stream=stream)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response
|
||||||
|
except (httpx.RemoteProtocolError, httpx.ConnectError):
|
||||||
|
# Retry the request with a new session if there is a connection error
|
||||||
|
new_client = self.create_client(timeout=timeout, concurrent_limit=1)
|
||||||
|
try:
|
||||||
|
return await self.single_connection_post_request(
|
||||||
|
url=url,
|
||||||
|
client=new_client,
|
||||||
|
data=data,
|
||||||
|
json=json,
|
||||||
|
params=params,
|
||||||
|
headers=headers,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await new_client.aclose()
|
||||||
|
except httpx.TimeoutException as e:
|
||||||
|
headers = {}
|
||||||
|
if hasattr(e, "response") and e.response is not None:
|
||||||
|
for key, value in e.response.headers.items():
|
||||||
|
headers["response_headers-{}".format(key)] = value
|
||||||
|
|
||||||
|
raise litellm.Timeout(
|
||||||
|
message=f"Connection timed out after {timeout} seconds.",
|
||||||
|
model="default-model-name",
|
||||||
|
llm_provider="litellm-httpx-handler",
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
setattr(e, "status_code", e.response.status_code)
|
||||||
|
if stream is True:
|
||||||
|
setattr(e, "message", await e.response.aread())
|
||||||
|
else:
|
||||||
|
setattr(e, "message", e.response.text)
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
async def delete(
|
async def delete(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
|
@ -274,6 +330,38 @@ class HTTPHandler:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def put(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
data: Optional[Union[dict, str]] = None,
|
||||||
|
json: Optional[Union[dict, str]] = None,
|
||||||
|
params: Optional[dict] = None,
|
||||||
|
headers: Optional[dict] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
|
||||||
|
if timeout is not None:
|
||||||
|
req = self.client.build_request(
|
||||||
|
"PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
req = self.client.build_request(
|
||||||
|
"PUT", url, data=data, json=json, params=params, headers=headers # type: ignore
|
||||||
|
)
|
||||||
|
response = self.client.send(req, stream=stream)
|
||||||
|
return response
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise litellm.Timeout(
|
||||||
|
message=f"Connection timed out after {timeout} seconds.",
|
||||||
|
model="default-model-name",
|
||||||
|
llm_provider="litellm-httpx-handler",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
try:
|
try:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
|
@ -1732,3 +1732,108 @@ def test_caching_redis_simple(caplog, capsys):
|
||||||
assert redis_async_caching_error is False
|
assert redis_async_caching_error is False
|
||||||
assert redis_service_logging_error is False
|
assert redis_service_logging_error is False
|
||||||
assert "async success_callback: reaches cache for logging" not in captured.out
|
assert "async success_callback: reaches cache for logging" not in captured.out
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qdrant_semantic_cache_acompletion():
|
||||||
|
random_number = random.randint(
|
||||||
|
1, 100000
|
||||||
|
) # add a random number to ensure it's always adding /reading from cache
|
||||||
|
|
||||||
|
print("Testing Qdrant Semantic Caching with acompletion")
|
||||||
|
|
||||||
|
litellm.cache = Cache(
|
||||||
|
type="qdrant-semantic",
|
||||||
|
qdrant_host_type="cloud",
|
||||||
|
qdrant_url=os.getenv("QDRANT_URL"),
|
||||||
|
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||||
|
qdrant_collection_name='test_collection',
|
||||||
|
similarity_threshold=0.8,
|
||||||
|
qdrant_quantization_config="binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
response1 = await litellm.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"write a one sentence poem about: {random_number}",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens=20,
|
||||||
|
)
|
||||||
|
print(f"Response1: {response1}")
|
||||||
|
|
||||||
|
random_number = random.randint(1, 100000)
|
||||||
|
|
||||||
|
response2 = await litellm.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"write a one sentence poem about: {random_number}",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens=20,
|
||||||
|
)
|
||||||
|
print(f"Response2: {response2}")
|
||||||
|
assert response1.id == response2.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qdrant_semantic_cache_acompletion_stream():
|
||||||
|
try:
|
||||||
|
random_word = generate_random_word()
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"write a joke about: {random_word}",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
litellm.cache = Cache(
|
||||||
|
type="qdrant-semantic",
|
||||||
|
qdrant_host_type="cloud",
|
||||||
|
qdrant_url=os.getenv("QDRANT_URL"),
|
||||||
|
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||||
|
qdrant_collection_name='test_collection',
|
||||||
|
similarity_threshold=0.8,
|
||||||
|
qdrant_quantization_config="binary"
|
||||||
|
)
|
||||||
|
print("Test Qdrant Semantic Caching with streaming + acompletion")
|
||||||
|
response_1_content = ""
|
||||||
|
response_2_content = ""
|
||||||
|
|
||||||
|
response1 = await litellm.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=40,
|
||||||
|
temperature=1,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
async for chunk in response1:
|
||||||
|
response_1_id = chunk.id
|
||||||
|
response_1_content += chunk.choices[0].delta.content or ""
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
response2 = await litellm.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=40,
|
||||||
|
temperature=1,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
async for chunk in response2:
|
||||||
|
response_2_id = chunk.id
|
||||||
|
response_2_content += chunk.choices[0].delta.content or ""
|
||||||
|
|
||||||
|
print("\nResponse 1", response_1_content, "\nResponse 1 id", response_1_id)
|
||||||
|
print("\nResponse 2", response_2_content, "\nResponse 2 id", response_2_id)
|
||||||
|
assert (
|
||||||
|
response_1_content == response_2_content
|
||||||
|
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
|
assert (response_1_id == response_2_id), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}"
|
||||||
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
except Exception as e:
|
||||||
|
print(f"{str(e)}\n\n{traceback.format_exc()}")
|
||||||
|
raise e
|
||||||
|
|
|
@ -121,7 +121,7 @@ import importlib.metadata
|
||||||
from openai import OpenAIError as OriginalError
|
from openai import OpenAIError as OriginalError
|
||||||
|
|
||||||
from ._logging import verbose_logger
|
from ._logging import verbose_logger
|
||||||
from .caching import RedisCache, RedisSemanticCache, S3Cache
|
from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIError,
|
APIError,
|
||||||
|
@ -1164,6 +1164,14 @@ def client(original_function):
|
||||||
cached_result = await litellm.cache.async_get_cache(
|
cached_result = await litellm.cache.async_get_cache(
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
)
|
)
|
||||||
|
elif isinstance(litellm.cache.cache, QdrantSemanticCache):
|
||||||
|
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
|
||||||
|
kwargs["preset_cache_key"] = (
|
||||||
|
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
|
||||||
|
)
|
||||||
|
cached_result = await litellm.cache.async_get_cache(
|
||||||
|
*args, **kwargs
|
||||||
|
)
|
||||||
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
|
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
|
||||||
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
|
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
|
||||||
kwargs["preset_cache_key"] = (
|
kwargs["preset_cache_key"] = (
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue