fix qdrant litellm on proxy

This commit is contained in:
Ishaan Jaff 2024-08-21 12:52:29 -07:00
parent c6dfd2d276
commit e7ecb2fe3a
5 changed files with 84 additions and 28 deletions

View file

@ -161,7 +161,7 @@ random_number = random.randint(
print("testing semantic caching")
litellm.cache = Cache(
type="qdrant-semantic",
qdrant_url=os.environ["QDRANT_API_BASE"],
qdrant_api_base=os.environ["QDRANT_API_BASE"],
qdrant_api_key=os.environ["QDRANT_API_KEY"],
qdrant_collection_name="your_collection_name", # any name of your collection
similarity_threshold=0.7, # similarity threshold for cache hits, 0 == no similarity, 1 = exact matches, 0.5 == 50% similarity
@ -490,7 +490,7 @@ def __init__(
disk_cache_dir=None,
# qdrant cache params
qdrant_url: Optional[str] = None,
qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,

View file

@ -7,6 +7,7 @@ Cache LLM Responses
LiteLLM supports:
- In Memory Cache
- Redis Cache
- Qdrant Semantic Cache
- Redis Semantic Cache
- s3 Bucket Cache
@ -182,6 +183,49 @@ REDIS_<redis-kwarg-name> = ""
$ litellm --config /path/to/config.yaml
```
</TabItem>
<TabItem value="qdrant-semantic" label="Qdrant Semantic cache">
Caching can be enabled by adding the `cache` key in the `config.yaml`
#### Step 1: Add `cache` to the config.yaml
```yaml
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: openai-embedding
litellm_params:
model: openai/text-embedding-3-small
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
set_verbose: True
cache: True # set cache responses to True, litellm defaults to using a redis cache
cache_params:
type: qdrant-semantic
qdrant_semantic_cache_embedding_model: openai-embedding # the model should be defined on the model_list
qdrant_collection_name: test_collection
qdrant_quantization_config: binary
similarity_threshold: 0.8 # similarity threshold for semantic cache
```
#### Step 2: Add Qdrant Credentials to your .env
```shell
QDRANT_API_KEY = "16rJUMBRx*************"
QDRANT_API_BASE = "https://5392d382-45*********.cloud.qdrant.io"
```
#### Step 3: Run proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
</TabItem>
</Tabs>

View file

@ -74,6 +74,7 @@ const sidebars = {
"proxy/alerting",
"proxy/ui",
"proxy/prometheus",
"proxy/caching",
"proxy/pass_through",
"proxy/email",
"proxy/multiple_admins",
@ -88,7 +89,6 @@ const sidebars = {
"proxy/health",
"proxy/debugging",
"proxy/pii_masking",
"proxy/caching",
"proxy/call_hooks",
"proxy/rules",
"proxy/cli",

View file

@ -1223,7 +1223,7 @@ class RedisSemanticCache(BaseCache):
class QdrantSemanticCache(BaseCache):
def __init__(
self,
qdrant_url=None,
qdrant_api_base=None,
qdrant_api_key=None,
collection_name=None,
similarity_threshold=None,
@ -1251,18 +1251,31 @@ class QdrantSemanticCache(BaseCache):
self.similarity_threshold = similarity_threshold
self.embedding_model = embedding_model
headers = {}
if qdrant_url is None:
qdrant_url = os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
if qdrant_api_key is None:
qdrant_api_key = os.getenv("QDRANT_API_KEY")
if qdrant_url is not None and qdrant_api_key is not None:
headers = {"api-key": qdrant_api_key, "Content-Type": "application/json"}
else:
raise Exception("Qdrant url and api_key must be")
self.qdrant_url = qdrant_url
# check if defined as os.environ/ variable
if qdrant_api_base:
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
"os.environ/"
):
qdrant_api_base = litellm.get_secret(qdrant_api_base)
if qdrant_api_key:
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
"os.environ/"
):
qdrant_api_key = litellm.get_secret(qdrant_api_key)
qdrant_api_base = (
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
)
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
headers = {"api-key": qdrant_api_key, "Content-Type": "application/json"}
if qdrant_api_key is None or qdrant_api_base is None:
raise ValueError("Qdrant url and api_key must be")
self.qdrant_api_base = qdrant_api_base
self.qdrant_api_key = qdrant_api_key
print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}")
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
self.headers = headers
@ -1274,7 +1287,7 @@ class QdrantSemanticCache(BaseCache):
"Quantization config is not provided. Default binary quantization will be used."
)
collection_exists = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}/exists",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
headers=self.headers,
)
if collection_exists.status_code != 200:
@ -1284,7 +1297,7 @@ class QdrantSemanticCache(BaseCache):
if collection_exists.json()["result"]["exists"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
@ -1312,7 +1325,7 @@ class QdrantSemanticCache(BaseCache):
)
new_collection_status = self.sync_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
json={
"vectors": {"size": 1536, "distance": "Cosine"},
"quantization_config": quantization_params,
@ -1321,7 +1334,7 @@ class QdrantSemanticCache(BaseCache):
)
if new_collection_status.json()["result"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_url}/collections/{self.collection_name}",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
@ -1378,7 +1391,7 @@ class QdrantSemanticCache(BaseCache):
]
}
keys = self.sync_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
@ -1417,7 +1430,7 @@ class QdrantSemanticCache(BaseCache):
}
search_response = self.sync_client.post(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
@ -1506,7 +1519,7 @@ class QdrantSemanticCache(BaseCache):
}
keys = await self.async_client.put(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
@ -1564,7 +1577,7 @@ class QdrantSemanticCache(BaseCache):
}
search_response = await self.async_client.post(
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
@ -2111,7 +2124,7 @@ class Cache:
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
disk_cache_dir=None,
qdrant_url: Optional[str] = None,
qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,
@ -2126,7 +2139,7 @@ class Cache:
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
@ -2156,7 +2169,7 @@ class Cache:
)
elif type == "qdrant-semantic":
self.cache = QdrantSemanticCache(
qdrant_url=qdrant_url,
qdrant_api_base=qdrant_api_base,
qdrant_api_key=qdrant_api_key,
collection_name=qdrant_collection_name,
similarity_threshold=similarity_threshold,

View file

@ -1746,7 +1746,7 @@ async def test_qdrant_semantic_cache_acompletion():
litellm.cache = Cache(
type="qdrant-semantic",
_host_type="cloud",
qdrant_url=os.getenv("QDRANT_URL"),
qdrant_api_base=os.getenv("QDRANT_URL"),
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
qdrant_collection_name="test_collection",
similarity_threshold=0.8,
@ -1794,8 +1794,7 @@ async def test_qdrant_semantic_cache_acompletion_stream():
]
litellm.cache = Cache(
type="qdrant-semantic",
qdrant_host_type="cloud",
qdrant_url=os.getenv("QDRANT_URL"),
qdrant_api_base=os.getenv("QDRANT_URL"),
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
qdrant_collection_name="test_collection",
similarity_threshold=0.8,