implemented RestAPI and added support for cloud and local Qdrant clusters

This commit is contained in:
Haadi Rakhangi 2024-08-19 20:46:30 +05:30
parent a047df3825
commit 9df92923d8
3 changed files with 238 additions and 129 deletions

View file

@ -146,11 +146,6 @@ assert response1.id == response2.id
<TabItem value="qdrant-sem" label="qdrant-semantic cache">
Install redis
```shell
pip install qdrant-client
```
You can set up your own cloud Qdrant cluster by following this: https://qdrant.tech/documentation/quickstart-cloud/
To set up a Qdrant cluster locally follow: https://qdrant.tech/documentation/quickstart/
@ -166,12 +161,12 @@ random_number = random.randint(
print("testing semantic caching")
litellm.cache = Cache(
type="qdrant-semantic",
qdrant_host_type="cloud", # can be either 'cloud' or 'local'
qdrant_url=os.environ["QDRANT_URL"],
qdrant_username=os.environ["QDRANT_USERNAME"]",
qdrant_password=os.environ["QDRANT_PASSWORD"],
qdrant_api_key=os.environ["QDRANT_API_KEY"],
qdrant_collection_name="your_collection_name", # any name of your collection
similarity_threshold=0.7, # similarity threshold for cache hits, 0 == no similarity, 1 = exact matches, 0.5 == 50% similarity
qdrant_quantization_config = "binary", # can be one of 'binary', 'product' or 'scalar' quantizations that is supported by qdrant
qdrant_quantization_config ="binary", # can be one of 'binary', 'product' or 'scalar' quantizations that is supported by qdrant
qdrant_semantic_cache_embedding_model="text-embedding-ada-002", # this model is passed to litellm.embedding(), any litellm.embedding() model is supported here
)
@ -496,12 +491,12 @@ def __init__(
disk_cache_dir=None,
# qdrant cache params
qdrant_username: Optional[str] = None,
qdrant_password: Optional[str] = None,
qdrant_url: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,
qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
**kwargs
):

View file

@ -1220,16 +1220,18 @@ class RedisSemanticCache(BaseCache):
class QdrantSemanticCache(BaseCache):
def __init__(
self,
qdrant_username=None,
qdrant_password=None,
qdrant_url=None,
qdrant_api_key = None,
collection_name=None,
similarity_threshold=None,
quantization_config=None,
embedding_model="text-embedding-ada-002"
embedding_model="text-embedding-ada-002",
host_type = None
):
from qdrant_client import models, AsyncQdrantClient, QdrantClient
import base64
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
_get_async_httpx_client
)
if collection_name is None:
raise Exception("collection_name must be provided, passed None")
@ -1244,73 +1246,109 @@ class QdrantSemanticCache(BaseCache):
self.similarity_threshold = similarity_threshold
self.embedding_model = embedding_model
if qdrant_url is None or qdrant_username is None or qdrant_password is None:
if host_type=="cloud":
import os
if qdrant_url is None:
qdrant_url = os.getenv('QDRANT_URL')
if qdrant_api_key is None:
qdrant_api_key = os.getenv('QDRANT_API_KEY')
if qdrant_url is not None and qdrant_api_key is not None:
headers = {
"api-key": qdrant_api_key,
"Content-Type": "application/json"
}
else:
raise Exception("Qdrant url and api_key must be provided for qdrant cloud hosting")
elif host_type=="local":
import os
if qdrant_url is None:
qdrant_url = os.getenv('QDRANT_URL')
if qdrant_url is None:
raise Exception("Qdrant url must be provided for qdrant local hosting")
if qdrant_api_key is None:
qdrant_api_key = os.getenv('QDRANT_API_KEY')
if qdrant_api_key is None:
print_verbose('Running locally without API Key.')
headers= {
"Content-Type": "application/json"
}
else:
print_verbose("Running locally with API Key")
headers = {
"api-key": qdrant_api_key,
"Content-Type": "application/json"
}
else:
raise Exception("Host type can be either 'local' or 'cloud'")
qdrant_url = os.getenv('QDRANT_URL')
qdrant_username = os.getenv('QDRANT_USERNAME')
qdrant_password = os.getenv('QDRANT PASSWORD')
if qdrant_url is None or qdrant_username is None or qdrant_password is None:
raise Exception("Qdrant url, username and password must be provided")
self.qdrant_url = qdrant_url
self.qdrant_api_key = qdrant_api_key
print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}")
print_verbose(f"qdrant semantic-cache qdrant_url: {qdrant_url}")
self.credentials = f"{qdrant_username}:{qdrant_password}"
self.encoded_credentials = base64.b64encode(self.credentials.encode()).decode()
self.headers = {
"Authorization": f"Basic {self.encoded_credentials}"
}
self.headers = headers
self.qdrant_client = QdrantClient(
url= qdrant_url,
timeout=1200,
headers=self.headers
)
self.sync_client = _get_httpx_client()
self.async_client = _get_async_httpx_client()
self.qdrant_client_async = AsyncQdrantClient(
url= qdrant_url,
timeout=1200,
headers=self.headers
)
if quantization_config is None:
print('Quantization config is not provided. Default binary quantization will be used.')
if self.qdrant_client.collection_exists(collection_name=f"{self.collection_name}"):
self.collection_info = self.qdrant_client.get_collection(f"{self.collection_name}")
collection_exists = self.sync_client.get(
url= f"{self.qdrant_url}/collections/{self.collection_name}/exists",
headers=self.headers
)
if collection_exists.json()['result']['exists']:
collection_details = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
headers=self.headers
)
self.collection_info = collection_details.json()
print_verbose(f'Collection already exists.\nCollection details:{self.collection_info}')
else:
if quantization_config is None or quantization_config == 'binary':
quantization_params = models.BinaryQuantization(
binary= models.BinaryQuantizationConfig(always_ram=False),
)
quantization_params = {
"binary": {
"always_ram": False,
}
}
elif quantization_config == 'scalar':
quantization_params = models.ScalarQuantization(
scalar=models.ScalarQuantizationConfig(
type=models.ScalarType.INT8,
quantile=0.99,
always_ram=False,
),
)
quantization_params = {
"scalar": {
"type": "int8",
"quantile": 0.99,
"always_ram": False
}
}
elif quantization_config == 'product':
quantization_params = models.ProductQuantization(
product=models.ProductQuantizationConfig(
compression=models.CompressionRatio.X16,
always_ram=False,
),
)
quantization_params = {
"product": {
"compression": "x16",
"always_ram": False
}
}
else:
raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")
self.qdrant_client.create_collection(
collection_name=f"{self.collection_name}",
vectors_config= models.VectorParams(
size=1536,
distance= models.Distance.COSINE
),
quantization_config= quantization_params
new_collection_status = self.sync_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
json={
"vectors": {
"size": 1536,
"distance": "Cosine"
},
"quantization_config": quantization_params
},
headers=self.headers
)
self.collection_info = self.qdrant_client.get_collection(f"{self.collection_name}")
print_verbose(f'New collection created.\nCollection details:{self.collection_info}')
if new_collection_status.json()["result"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
headers=self.headers
)
self.collection_info = collection_details.json()
print_verbose(f'New collection created.\nCollection details:{self.collection_info}')
else:
raise Exception("Error while creating new collection")
def _get_cache_logic(self, cached_response: Any):
if cached_response is None:
@ -1325,7 +1363,6 @@ class QdrantSemanticCache(BaseCache):
def set_cache(self, key, value, **kwargs):
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
from qdrant_client import models
import uuid
# get the prompt
@ -1347,24 +1384,27 @@ class QdrantSemanticCache(BaseCache):
value = str(value)
assert isinstance(value, str)
keys = self.qdrant_client.upsert(
collection_name=f"{self.collection_name}",
points=[
models.PointStruct(
id=str(uuid.uuid4()),
payload={
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
vector= embedding,
),
}
},
]
}
keys = self.sync_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
headers=self.headers,
json=data
)
return
def get_cache(self, key, **kwargs):
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
from qdrant_client import models
# get the messages
messages = kwargs["messages"]
@ -1382,19 +1422,25 @@ class QdrantSemanticCache(BaseCache):
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
results = self.qdrant_client.search(
collection_name=self.collection_name,
query_vector= embedding,
search_params= models.SearchParams(
quantization= models.QuantizationSearchParams(
ignore=False,
rescore=True,
oversampling=3.0,
),
exact=False,
),
limit=1
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit":1,
"with_payload": True
}
search_response = self.sync_client.post(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data
)
results = search_response.json()["result"]
if results == None:
return None
@ -1402,8 +1448,8 @@ class QdrantSemanticCache(BaseCache):
if len(results) == 0:
return None
similarity = results[0].score
cached_prompt = results[0].payload['text']
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
@ -1411,7 +1457,7 @@ class QdrantSemanticCache(BaseCache):
)
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0].payload['response']
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
@ -1423,7 +1469,6 @@ class QdrantSemanticCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs):
from litellm.proxy.proxy_server import llm_router, llm_model_list
from qdrant_client import models
import uuid
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
@ -1464,24 +1509,28 @@ class QdrantSemanticCache(BaseCache):
value = str(value)
assert isinstance(value, str)
keys = await self.qdrant_client_async.upsert(
collection_name=f"{self.collection_name}",
points=[
models.PointStruct(
id=str(uuid.uuid4()),
payload={
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
vector= embedding,
),
}
},
]
}
keys = await self.async_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
headers=self.headers,
json=data
)
return
async def async_get_cache(self, key, **kwargs):
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
from qdrant_client import models
from litellm.proxy.proxy_server import llm_router, llm_model_list
# get the messages
@ -1518,20 +1567,27 @@ class QdrantSemanticCache(BaseCache):
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
results = await self.qdrant_client_async.search(
collection_name=self.collection_name,
query_vector= embedding,
search_params= models.SearchParams(
quantization= models.QuantizationSearchParams(
ignore=False,
rescore=True,
oversampling=3.0,
),
exact=False,
),
limit=1
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit":1,
"with_payload": True
}
search_response = await self.async_client.post(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data
)
results = search_response.json()["result"]
if results == None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
@ -1540,8 +1596,8 @@ class QdrantSemanticCache(BaseCache):
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
similarity = results[0].score
cached_prompt = results[0].payload['text']
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
@ -1553,7 +1609,7 @@ class QdrantSemanticCache(BaseCache):
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0].payload['response']
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
@ -2070,12 +2126,12 @@ class Cache:
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
disk_cache_dir=None,
qdrant_username: Optional[str] = None,
qdrant_password: Optional[str] = None,
qdrant_url: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,
qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
**kwargs,
):
"""
@ -2086,11 +2142,11 @@ class Cache:
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic"
qdrant_username (str, optional): The username for the qdrant cluster. Required if type is "qdrant-semantic"
qdrant_password (str, optional): The password for the qdrant cluster. Required if type is "qdrant-semantic"
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic"
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic"
qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. Required if qdrant_host_type is "cloud" and optional if qdrant_host_type is "local".
qdrant_host_type (str, optional): Can be either "local" or "cloud". Should be "local" when you are running a local qdrant cluster or "cloud" when you are using a qdrant cloud cluster.
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache
@ -2117,13 +2173,13 @@ class Cache:
)
elif type == "qdrant-semantic":
self.cache = QdrantSemanticCache(
qdrant_username= qdrant_username,
qdrant_password= qdrant_password,
qdrant_url= qdrant_url,
qdrant_api_key= qdrant_api_key,
collection_name= qdrant_collection_name,
similarity_threshold= similarity_threshold,
quantization_config= qdrant_quantization_config,
embedding_model= qdrant_semantic_cache_embedding_model,
host_type=qdrant_host_type
)
elif type == "local":
self.cache = InMemoryCache()

View file

@ -114,6 +114,48 @@ class AsyncHTTPHandler:
except Exception as e:
raise e
async def put(
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
try:
req = self.client.build_request(
"PUT", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = await self.client.send(req, stream=stream)
response.raise_for_status()
return response
except (httpx.RemoteProtocolError, httpx.ConnectError):
# Retry the request with a new session if there is a connection error
new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
try:
return await self.single_connection_post_request(
url=url,
client=new_client,
data=data,
json=json,
params=params,
headers=headers,
stream=stream,
)
finally:
await new_client.aclose()
except httpx.HTTPStatusError as e:
setattr(e, "status_code", e.response.status_code)
if stream is True:
setattr(e, "message", await e.response.aread())
else:
setattr(e, "message", e.response.text)
raise e
except Exception as e:
raise e
async def single_connection_post_request(
self,
url: str,
@ -200,6 +242,22 @@ class HTTPHandler:
response = self.client.send(req, stream=stream)
return response
def put(
self,
url: str,
data: Optional[Union[dict, str]] = None,
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"PUT", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response
def __del__(self) -> None:
try:
self.close()