fix drant url

This commit is contained in:
Ishaan Jaff 2024-08-21 12:07:57 -07:00
parent 8812da04e3
commit 428a74be07
2 changed files with 104 additions and 117 deletions

View file

@ -1219,25 +1219,28 @@ class RedisSemanticCache(BaseCache):
async def _index_info(self):
return await self.index.ainfo()
class QdrantSemanticCache(BaseCache):
def __init__(
self,
qdrant_url=None,
qdrant_api_key = None,
collection_name=None,
similarity_threshold=None,
quantization_config=None,
embedding_model="text-embedding-ada-002",
host_type = None
):
self,
qdrant_url=None,
qdrant_api_key=None,
collection_name=None,
similarity_threshold=None,
quantization_config=None,
embedding_model="text-embedding-ada-002",
host_type=None,
):
import os
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
_get_async_httpx_client
_get_async_httpx_client,
_get_httpx_client,
)
if collection_name is None:
raise Exception("collection_name must be provided, passed None")
self.collection_name = collection_name
print_verbose(
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
@ -1247,108 +1250,84 @@ class QdrantSemanticCache(BaseCache):
raise Exception("similarity_threshold must be provided, passed None")
self.similarity_threshold = similarity_threshold
self.embedding_model = embedding_model
if host_type=="cloud":
import os
if qdrant_url is None:
qdrant_url = os.getenv('QDRANT_URL')
if qdrant_api_key is None:
qdrant_api_key = os.getenv('QDRANT_API_KEY')
if qdrant_url is not None and qdrant_api_key is not None:
headers = {
"api-key": qdrant_api_key,
"Content-Type": "application/json"
}
else:
raise Exception("Qdrant url and api_key must be provided for qdrant cloud hosting")
elif host_type=="local":
import os
if qdrant_url is None:
qdrant_url = os.getenv('QDRANT_URL')
if qdrant_url is None:
raise Exception("Qdrant url must be provided for qdrant local hosting")
if qdrant_api_key is None:
qdrant_api_key = os.getenv('QDRANT_API_KEY')
if qdrant_api_key is None:
print_verbose('Running locally without API Key.')
headers= {
"Content-Type": "application/json"
}
else:
print_verbose("Running locally with API Key")
headers = {
"api-key": qdrant_api_key,
"Content-Type": "application/json"
}
headers = {}
if qdrant_url is None:
qdrant_url = os.getenv("QDRANT_URL")
if qdrant_api_key is None:
qdrant_api_key = os.getenv("QDRANT_API_KEY")
if qdrant_url is not None and qdrant_api_key is not None:
headers = {"api-key": qdrant_api_key, "Content-Type": "application/json"}
else:
raise Exception("Host type can be either 'local' or 'cloud'")
raise Exception("Qdrant url and api_key must be")
self.qdrant_url = qdrant_url
self.qdrant_api_key = qdrant_api_key
print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}")
self.headers = headers
self.sync_client = _get_httpx_client()
self.async_client = _get_async_httpx_client()
if quantization_config is None:
print('Quantization config is not provided. Default binary quantization will be used.')
collection_exists = self.sync_client.get(
url= f"{self.qdrant_url}/collections/{self.collection_name}/exists",
headers=self.headers
print_verbose(
"Quantization config is not provided. Default binary quantization will be used."
)
if collection_exists.json()['result']['exists']:
collection_exists = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}/exists",
headers=self.headers,
)
if collection_exists.status_code != 200:
raise ValueError(
f"Error from qdrant checking if /collections exist {collection_exists.text}"
)
if collection_exists.json()["result"]["exists"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
headers=self.headers
)
url=f"{self.qdrant_url}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(f'Collection already exists.\nCollection details:{self.collection_info}')
print_verbose(
f"Collection already exists.\nCollection details:{self.collection_info}"
)
else:
if quantization_config is None or quantization_config == 'binary':
if quantization_config is None or quantization_config == "binary":
quantization_params = {
"binary": {
"always_ram": False,
}
}
elif quantization_config == 'scalar':
elif quantization_config == "scalar":
quantization_params = {
"scalar": {
"type": "int8",
"quantile": 0.99,
"always_ram": False
}
"scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}
}
elif quantization_config == 'product':
elif quantization_config == "product":
quantization_params = {
"product": {
"compression": "x16",
"always_ram": False
}
"product": {"compression": "x16", "always_ram": False}
}
else:
raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")
else:
raise Exception(
"Quantization config must be one of 'scalar', 'binary' or 'product'"
)
new_collection_status = self.sync_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
json={
"vectors": {
"size": 1536,
"distance": "Cosine"
},
"quantization_config": quantization_params
"vectors": {"size": 1536, "distance": "Cosine"},
"quantization_config": quantization_params,
},
headers=self.headers
headers=self.headers,
)
if new_collection_status.json()["result"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
headers=self.headers
)
url=f"{self.qdrant_url}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(f'New collection created.\nCollection details:{self.collection_info}')
print_verbose(
f"New collection created.\nCollection details:{self.collection_info}"
)
else:
raise Exception("Error while creating new collection")
@ -1394,14 +1373,14 @@ class QdrantSemanticCache(BaseCache):
"payload": {
"text": prompt,
"response": value,
}
},
},
]
}
keys = self.sync_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
headers=self.headers,
json=data
json=data,
)
return
@ -1433,14 +1412,14 @@ class QdrantSemanticCache(BaseCache):
"oversampling": 3.0,
}
},
"limit":1,
"with_payload": True
"limit": 1,
"with_payload": True,
}
search_response = self.sync_client.post(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data
json=data,
)
results = search_response.json()["result"]
@ -1470,8 +1449,10 @@ class QdrantSemanticCache(BaseCache):
pass
async def async_set_cache(self, key, value, **kwargs):
from litellm.proxy.proxy_server import llm_router, llm_model_list
import uuid
from litellm.proxy.proxy_server import llm_model_list, llm_router
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
@ -1519,7 +1500,7 @@ class QdrantSemanticCache(BaseCache):
"payload": {
"text": prompt,
"response": value,
}
},
},
]
}
@ -1527,13 +1508,13 @@ class QdrantSemanticCache(BaseCache):
keys = await self.async_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
headers=self.headers,
json=data
json=data,
)
return
async def async_get_cache(self, key, **kwargs):
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
from litellm.proxy.proxy_server import llm_router, llm_model_list
from litellm.proxy.proxy_server import llm_model_list, llm_router
# get the messages
messages = kwargs["messages"]
@ -1578,14 +1559,14 @@ class QdrantSemanticCache(BaseCache):
"oversampling": 3.0,
}
},
"limit":1,
"with_payload": True
"limit": 1,
"with_payload": True,
}
search_response = await self.async_client.post(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data
json=data,
)
results = search_response.json()["result"]
@ -1624,6 +1605,7 @@ class QdrantSemanticCache(BaseCache):
async def _collection_info(self):
return self.collection_info
class S3Cache(BaseCache):
def __init__(
self,
@ -2134,7 +2116,7 @@ class Cache:
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,
qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
qdrant_host_type: Optional[Literal["local", "cloud"]] = "local",
**kwargs,
):
"""
@ -2176,13 +2158,13 @@ class Cache:
)
elif type == "qdrant-semantic":
self.cache = QdrantSemanticCache(
qdrant_url= qdrant_url,
qdrant_api_key= qdrant_api_key,
collection_name= qdrant_collection_name,
similarity_threshold= similarity_threshold,
quantization_config= qdrant_quantization_config,
embedding_model= qdrant_semantic_cache_embedding_model,
host_type=qdrant_host_type
qdrant_url=qdrant_url,
qdrant_api_key=qdrant_api_key,
collection_name=qdrant_collection_name,
similarity_threshold=similarity_threshold,
quantization_config=qdrant_quantization_config,
embedding_model=qdrant_semantic_cache_embedding_model,
host_type=qdrant_host_type,
)
elif type == "local":
self.cache = InMemoryCache()