From 851db5eceafb159d0f6f97a2b10277b09715ee71 Mon Sep 17 00:00:00 2001
From: Haadi Rakhangi <127193364+haadirakhangi@users.noreply.github.com>
Date: Fri, 2 Aug 2024 21:07:19 +0530
Subject: [PATCH] qdrant semantic caching added

---
 docs/my-website/docs/caching/all_caches.md |  70 +++-
 litellm/caching.py                         | 374 ++++++++++++++++++++-
 litellm/utils.py                           |  10 +-
 3 files changed, 449 insertions(+), 5 deletions(-)

diff --git a/docs/my-website/docs/caching/all_caches.md b/docs/my-website/docs/caching/all_caches.md
index 1b8bbd8e0..c9cd2fc78 100644
--- a/docs/my-website/docs/caching/all_caches.md
+++ b/docs/my-website/docs/caching/all_caches.md
@@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht
 
 :::
 
-## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
+## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache, Qdrant Semantic
 
 
 
@@ -144,7 +144,67 @@ assert response1.id == response2.id
 
 
+
+Install the Qdrant client
+```shell
+pip install qdrant-client
+```
+
+You can set up your own Qdrant Cloud cluster by following this guide: https://qdrant.tech/documentation/quickstart-cloud/
+
+To set up a Qdrant cluster locally, follow: https://qdrant.tech/documentation/quickstart/
+```python
+import os
+import random
+
+import litellm
+from litellm import completion
+from litellm.caching import Cache
+
+random_number = random.randint(
+    1, 100000
+)  # add a random number to ensure it's always adding / reading from cache
+
+print("testing semantic caching")
+litellm.cache = Cache(
+    type="qdrant-semantic",
+    qdrant_url=os.environ["QDRANT_URL"],
+    qdrant_username=os.environ["QDRANT_USERNAME"],
+    qdrant_password=os.environ["QDRANT_PASSWORD"],
+    qdrant_collection_name="your_collection_name",  # any name for your collection
+    similarity_threshold=0.7,  # similarity threshold for cache hits, 0 == no similarity, 1 == exact match, 0.5 == 50% similarity
+    qdrant_quantization_config="binary",  # one of the 'binary', 'product' or 'scalar' quantizations supported by qdrant
+    qdrant_semantic_cache_embedding_model="text-embedding-ada-002",  # this model is passed to litellm.embedding(), any litellm.embedding() model is supported here
+)
+
+response1 = completion(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {random_number}",
+        }
+    ],
+    max_tokens=20,
+)
+print(f"response1: {response1}")
+
+random_number = random.randint(1, 100000)
+
+response2 = completion(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {random_number}",
+        }
+    ],
+    max_tokens=20,
+)
+print(f"response2: {response2}")
+assert response1.id == response2.id
+# response1 == response2, response 1 is cached
+```
+
+
@@ -435,6 +495,14 @@ def __init__(
     # disk cache params
     disk_cache_dir=None,
+    # qdrant cache params
+    qdrant_username: Optional[str] = None,
+    qdrant_password: Optional[str] = None,
+    qdrant_url: Optional[str] = None,
+    qdrant_collection_name: Optional[str] = None,
+    qdrant_quantization_config: Optional[str] = None,
+    qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
+    **kwargs
 ):
 ```

diff --git a/litellm/caching.py b/litellm/caching.py
index fa10095da..06b15714e 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -1217,6 +1217,354 @@ class RedisSemanticCache(BaseCache):
     async def _index_info(self):
         return await self.index.ainfo()
 
+class QdrantSemanticCache(BaseCache):
+    def __init__(
+        self,
+        qdrant_username=None,
+        qdrant_password=None,
+        qdrant_url=None,
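+        # collection_name and similarity_threshold are required (validated below);
+        # quantization_config may be 'binary' (default), 'scalar' or 'product'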
+        collection_name=None,
+        similarity_threshold=None,
+        quantization_config=None,
+        embedding_model="text-embedding-ada-002"
+    ):
+        from qdrant_client import models, AsyncQdrantClient, QdrantClient
+        import base64
+
+        if collection_name is None:
+            raise Exception("collection_name must be provided, passed None")
+
+        self.collection_name = collection_name
+        print_verbose(
+            f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
+        )
+
+        if similarity_threshold is None:
+            raise Exception("similarity_threshold must be provided, passed None")
+        self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
+
+        if qdrant_url is None or qdrant_username is None or qdrant_password is None:
+            import os
+
+            qdrant_url = os.getenv('QDRANT_URL')
+            qdrant_username = os.getenv('QDRANT_USERNAME')
+            qdrant_password = os.getenv('QDRANT_PASSWORD')
+            if qdrant_url is None or qdrant_username is None or qdrant_password is None:
+                raise Exception("Qdrant url, username and password must be provided")
+
+        print_verbose(f"qdrant semantic-cache qdrant_url: {qdrant_url}")
+        self.credentials = f"{qdrant_username}:{qdrant_password}"
+        self.encoded_credentials = base64.b64encode(self.credentials.encode()).decode()
+        self.headers = {
+            "Authorization": f"Basic {self.encoded_credentials}"
+        }
+
+        self.qdrant_client = QdrantClient(
+            url=qdrant_url,
+            timeout=1200,
+            headers=self.headers
+        )
+
+        self.qdrant_client_async = AsyncQdrantClient(
+            url=qdrant_url,
+            timeout=1200,
+            headers=self.headers
+        )
+        if quantization_config is None:
+            print_verbose(
+                "Quantization config is not provided. Default binary quantization will be used."
+            )
+
+        if self.qdrant_client.collection_exists(collection_name=f"{self.collection_name}"):
+            self.collection_info = self.qdrant_client.get_collection(f"{self.collection_name}")
+            print_verbose(f'Collection already exists.\nCollection details: {self.collection_info}')
+        else:
+            if quantization_config is None or quantization_config == 'binary':
+                quantization_params = models.BinaryQuantization(
+                    binary=models.BinaryQuantizationConfig(always_ram=False),
+                )
+            elif quantization_config == 'scalar':
+                quantization_params = models.ScalarQuantization(
+                    scalar=models.ScalarQuantizationConfig(
+                        type=models.ScalarType.INT8,
+                        quantile=0.99,
+                        always_ram=False,
+                    ),
+                )
+            elif quantization_config == 'product':
+                quantization_params = models.ProductQuantization(
+                    product=models.ProductQuantizationConfig(
+                        compression=models.CompressionRatio.X16,
+                        always_ram=False,
+                    ),
+                )
+            else:
+                raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")
+
+            self.qdrant_client.create_collection(
+                collection_name=f"{self.collection_name}",
+                vectors_config=models.VectorParams(
+                    size=1536,
+                    distance=models.Distance.COSINE
+                ),
+                quantization_config=quantization_params
+            )
+
+            self.collection_info = self.qdrant_client.get_collection(f"{self.collection_name}")
+            print_verbose(f'New collection created.\nCollection details: {self.collection_info}')
+
+    def _get_cache_logic(self, cached_response: Any):
+        if cached_response is None:
+            return cached_response
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except Exception:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
+    def set_cache(self, key, value, **kwargs):
+        print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
+        from qdrant_client import models
+        import uuid
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
"" + for message in messages: + prompt += message["content"] + + # create an embedding for prompt + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + value = str(value) + assert isinstance(value, str) + + keys = self.qdrant_client.upsert( + collection_name=f"{self.collection_name}", + points=[ + models.PointStruct( + id=str(uuid.uuid4()), + payload={ + "text": prompt, + "response": value, + }, + vector= embedding, + ), + ] + ) + return + + def get_cache(self, key, **kwargs): + print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}") + from qdrant_client import models + + # get the messages + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + # convert to embedding + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + results = self.qdrant_client.search( + collection_name=self.collection_name, + query_vector= embedding, + search_params= models.SearchParams( + quantization= models.QuantizationSearchParams( + ignore=False, + rescore=True, + oversampling=3.0, + ), + exact=False, + ), + limit=1 + ) + + if results == None: + return None + if isinstance(results, list): + if len(results) == 0: + return None + + similarity = results[0].score + cached_prompt = results[0].payload['text'] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + if similarity >= self.similarity_threshold: + # cache hit ! + cached_value = results[0].payload['response'] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! 
+            return None
+
+    async def async_set_cache(self, key, value, **kwargs):
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from qdrant_client import models
+        import uuid
+
+        print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # create an embedding for prompt
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+                metadata={
+                    "user_api_key": user_api_key,
+                    "semantic-cache-embedding": True,
+                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+                },
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        value = str(value)
+        assert isinstance(value, str)
+
+        keys = await self.qdrant_client_async.upsert(
+            collection_name=f"{self.collection_name}",
+            points=[
+                models.PointStruct(
+                    id=str(uuid.uuid4()),
+                    payload={
+                        "text": prompt,
+                        "response": value,
+                    },
+                    vector=embedding,
+                ),
+            ]
+        )
+        return
+
+    async def async_get_cache(self, key, **kwargs):
+        print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
+        from qdrant_client import models
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
+
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+                metadata={
+                    "user_api_key": user_api_key,
+                    "semantic-cache-embedding": True,
+                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+                },
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        results = await self.qdrant_client_async.search(
+            collection_name=self.collection_name,
+            query_vector=embedding,
+            search_params=models.SearchParams(
+                quantization=models.QuantizationSearchParams(
+                    ignore=False,
+                    rescore=True,
+                    oversampling=3.0,
+                ),
+                exact=False,
+            ),
+            limit=1
+        )
+
+        if results is None:
+            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+                return None
+
+        similarity = results[0].score
+        cached_prompt = results[0].payload['text']
+
+        # check similarity, if more than self.similarity_threshold, return results
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+
+        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
+        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
+
+        if similarity >= self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0].payload['response']
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
+
+    async def _collection_info(self):
+        return self.collection_info
 
 class S3Cache(BaseCache):
     def __init__(
@@ -1673,7 +2021,7 @@ class Cache:
     def __init__(
         self,
         type: Optional[
-            Literal["local", "redis", "redis-semantic", "s3", "disk"]
+            Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]
         ] = "local",
         host: Optional[str] = None,
         port: Optional[str] = None,
@@ -1722,17 +2070,27 @@ class Cache:
         redis_semantic_cache_embedding_model="text-embedding-ada-002",
         redis_flush_size=None,
         disk_cache_dir=None,
+        qdrant_username: Optional[str] = None,
+        qdrant_password: Optional[str] = None,
+        qdrant_url: Optional[str] = None,
+        qdrant_collection_name: Optional[str] = None,
+        qdrant_quantization_config: Optional[str] = None,
+        qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         """
         Initializes the cache based on the given type.
 
         Args:
-            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
+            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
             host (str, optional): The host address for the Redis cache. Required if type is "redis".
             port (int, optional): The port number for the Redis cache. Required if type is "redis".
             password (str, optional): The password for the Redis cache. Required if type is "redis".
-            similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic"
+            qdrant_url (str, optional): The URL of your Qdrant cluster. Required if type is "qdrant-semantic".
+            qdrant_username (str, optional): The username for the Qdrant cluster. Required if type is "qdrant-semantic".
+            qdrant_password (str, optional): The password for the Qdrant cluster. Required if type is "qdrant-semantic".
+            qdrant_collection_name (str, optional): The name of your Qdrant collection. Required if type is "qdrant-semantic".
+            similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic" or "qdrant-semantic".
             supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
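+            qdrant_quantization_config (str, optional): The quantization method for the Qdrant collection, one of "binary", "scalar" or "product". Defaults to binary quantization when type is "qdrant-semantic".
+            qdrant_semantic_cache_embedding_model (str, optional): The embedding model passed to litellm.embedding() for semantic caching. Defaults to "text-embedding-ada-002".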
             **kwargs: Additional keyword arguments for redis.Redis() cache
 
@@ -1757,6 +2115,16 @@ class Cache:
                 embedding_model=redis_semantic_cache_embedding_model,
                 **kwargs,
             )
+        elif type == "qdrant-semantic":
+            self.cache = QdrantSemanticCache(
+                qdrant_username=qdrant_username,
+                qdrant_password=qdrant_password,
+                qdrant_url=qdrant_url,
+                collection_name=qdrant_collection_name,
+                similarity_threshold=similarity_threshold,
+                quantization_config=qdrant_quantization_config,
+                embedding_model=qdrant_semantic_cache_embedding_model,
+            )
         elif type == "local":
             self.cache = InMemoryCache()
         elif type == "s3":
diff --git a/litellm/utils.py b/litellm/utils.py
index 84b15cb19..1475639ea 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -113,7 +113,7 @@ import importlib.metadata
 from openai import OpenAIError as OriginalError
 
 from ._logging import verbose_logger
-from .caching import RedisCache, RedisSemanticCache, S3Cache
+from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
 from .exceptions import (
     APIConnectionError,
     APIError,
@@ -1114,6 +1114,14 @@ def client(original_function):
                             cached_result = await litellm.cache.async_get_cache(
                                 *args, **kwargs
                             )
+                        elif isinstance(litellm.cache.cache, QdrantSemanticCache):
+                            preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
+                            kwargs["preset_cache_key"] = (
+                                preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+                            )
+                            cached_result = await litellm.cache.async_get_cache(
+                                *args, **kwargs
+                            )
                         else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                             preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                             kwargs["preset_cache_key"] = (