diff --git a/docs/my-website/docs/caching/all_caches.md b/docs/my-website/docs/caching/all_caches.md
index 1b8bbd8e09..c46f6d22cf 100644
--- a/docs/my-website/docs/caching/all_caches.md
+++ b/docs/my-website/docs/caching/all_caches.md
@@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht
 :::
 
-## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
+## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache, Qdrant Semantic
 
@@ -144,7 +144,62 @@ assert response1.id == response2.id
 
+
+You can set up your own Qdrant Cloud cluster by following this guide: https://qdrant.tech/documentation/quickstart-cloud/
+
+To set up a Qdrant cluster locally, follow: https://qdrant.tech/documentation/quickstart/
+```python
+import os
+import random
+
+import litellm
+from litellm import completion
+from litellm.caching import Cache
+
+random_number = random.randint(
+    1, 100000
+)  # add a random number to ensure it's always adding / reading from cache
+
+print("testing semantic caching")
+litellm.cache = Cache(
+    type="qdrant-semantic",
+    qdrant_host_type="cloud",  # can be either 'cloud' or 'local'
+    qdrant_url=os.environ["QDRANT_URL"],
+    qdrant_api_key=os.environ["QDRANT_API_KEY"],
+    qdrant_collection_name="your_collection_name",  # any name for your collection
+    similarity_threshold=0.7,  # similarity threshold for cache hits: 0 == no similarity required, 1 == exact match, 0.5 == 50% similarity
+    qdrant_quantization_config="binary",  # one of the 'binary', 'product' or 'scalar' quantizations supported by Qdrant
+    qdrant_semantic_cache_embedding_model="text-embedding-ada-002",  # this model is passed to litellm.embedding(); any litellm.embedding() model is supported here
+)
+
+response1 = completion(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {random_number}",
+        }
+    ],
+    max_tokens=20,
+)
+print(f"response1: {response1}")
+
+random_number = random.randint(1, 100000)
+
+response2 = completion(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {random_number}",
+        }
+    ],
+    max_tokens=20,
+)
+print(f"response2: {response2}")
+assert response1.id == response2.id
+# response1 == response2, response1 was served from the cache
+```
+
+
@@ -435,6 +490,14 @@ def __init__(
     # disk cache params
     disk_cache_dir=None,
+    # qdrant cache params
+    qdrant_url: Optional[str] = None,
+    qdrant_api_key: Optional[str] = None,
+    qdrant_collection_name: Optional[str] = None,
+    qdrant_quantization_config: Optional[str] = None,
+    qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
+    qdrant_host_type: Optional[Literal["local", "cloud"]] = "local",
+    **kwargs
 ):
 ```
diff --git a/litellm/caching.py b/litellm/caching.py
index e37811b773..0615e533f8 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -1219,6 +1219,410 @@ class RedisSemanticCache(BaseCache):
     async def _index_info(self):
         return await self.index.ainfo()
 
+
+class QdrantSemanticCache(BaseCache):
+    def __init__(
+        self,
+        qdrant_url=None,
+        qdrant_api_key=None,
+        collection_name=None,
+        similarity_threshold=None,
+        quantization_config=None,
+        embedding_model="text-embedding-ada-002",
+        host_type=None,
+    ):
+        from litellm.llms.custom_httpx.http_handler import (
+            _get_httpx_client,
+            _get_async_httpx_client,
+        )
+
+        if collection_name is None:
+            raise Exception("collection_name must be provided, passed None")
+
+        self.collection_name = collection_name
+        print_verbose(
+            f"qdrant semantic-cache initializing 
COLLECTION - {self.collection_name}" + ) + + if similarity_threshold is None: + raise Exception("similarity_threshold must be provided, passed None") + self.similarity_threshold = similarity_threshold + self.embedding_model = embedding_model + + if host_type=="cloud": + import os + if qdrant_url is None: + qdrant_url = os.getenv('QDRANT_URL') + if qdrant_api_key is None: + qdrant_api_key = os.getenv('QDRANT_API_KEY') + if qdrant_url is not None and qdrant_api_key is not None: + headers = { + "api-key": qdrant_api_key, + "Content-Type": "application/json" + } + else: + raise Exception("Qdrant url and api_key must be provided for qdrant cloud hosting") + elif host_type=="local": + import os + if qdrant_url is None: + qdrant_url = os.getenv('QDRANT_URL') + if qdrant_url is None: + raise Exception("Qdrant url must be provided for qdrant local hosting") + if qdrant_api_key is None: + qdrant_api_key = os.getenv('QDRANT_API_KEY') + if qdrant_api_key is None: + print_verbose('Running locally without API Key.') + headers= { + "Content-Type": "application/json" + } + else: + print_verbose("Running locally with API Key") + headers = { + "api-key": qdrant_api_key, + "Content-Type": "application/json" + } + else: + raise Exception("Host type can be either 'local' or 'cloud'") + + self.qdrant_url = qdrant_url + self.qdrant_api_key = qdrant_api_key + print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}") + + self.headers = headers + + self.sync_client = _get_httpx_client() + self.async_client = _get_async_httpx_client() + + if quantization_config is None: + print('Quantization config is not provided. Default binary quantization will be used.') + + collection_exists = self.sync_client.get( + url= f"{self.qdrant_url}/collections/{self.collection_name}/exists", + headers=self.headers + ) + if collection_exists.json()['result']['exists']: + collection_details = self.sync_client.get( + url=f"{self.qdrant_url}/collections/{self.collection_name}", + headers=self.headers + ) + self.collection_info = collection_details.json() + print_verbose(f'Collection already exists.\nCollection details:{self.collection_info}') + else: + if quantization_config is None or quantization_config == 'binary': + quantization_params = { + "binary": { + "always_ram": False, + } + } + elif quantization_config == 'scalar': + quantization_params = { + "scalar": { + "type": "int8", + "quantile": 0.99, + "always_ram": False + } + } + elif quantization_config == 'product': + quantization_params = { + "product": { + "compression": "x16", + "always_ram": False + } + } + else: + raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'") + + new_collection_status = self.sync_client.put( + url=f"{self.qdrant_url}/collections/{self.collection_name}", + json={ + "vectors": { + "size": 1536, + "distance": "Cosine" + }, + "quantization_config": quantization_params + }, + headers=self.headers + ) + if new_collection_status.json()["result"]: + collection_details = self.sync_client.get( + url=f"{self.qdrant_url}/collections/{self.collection_name}", + headers=self.headers + ) + self.collection_info = collection_details.json() + print_verbose(f'New collection created.\nCollection details:{self.collection_info}') + else: + raise Exception("Error while creating new collection") + + def _get_cache_logic(self, cached_response: Any): + if cached_response is None: + return cached_response + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except: + cached_response = 
ast.literal_eval(cached_response) + return cached_response + + def set_cache(self, key, value, **kwargs): + print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}") + import uuid + + # get the prompt + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + # create an embedding for prompt + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + value = str(value) + assert isinstance(value, str) + + data = { + "points": [ + { + "id": str(uuid.uuid4()), + "vector": embedding, + "payload": { + "text": prompt, + "response": value, + } + }, + ] + } + keys = self.sync_client.put( + url=f"{self.qdrant_url}/collections/{self.collection_name}/points", + headers=self.headers, + json=data + ) + return + + def get_cache(self, key, **kwargs): + print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}") + + # get the messages + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + # convert to embedding + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + data = { + "vector": embedding, + "params": { + "quantization": { + "ignore": False, + "rescore": True, + "oversampling": 3.0, + } + }, + "limit":1, + "with_payload": True + } + + search_response = self.sync_client.post( + url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search", + headers=self.headers, + json=data + ) + results = search_response.json()["result"] + + if results == None: + return None + if isinstance(results, list): + if len(results) == 0: + return None + + similarity = results[0]["score"] + cached_prompt = results[0]["payload"]["text"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + if similarity >= self.similarity_threshold: + # cache hit ! + cached_value = results[0]["payload"]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! 
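+            # the closest stored prompt scored below similarity_threshold, so this is
+            # treated as a miss: the caller falls through to a fresh LLM call, and
+            # set_cache() will store that new prompt/response pair for future lookups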
+ return None + pass + + async def async_set_cache(self, key, value, **kwargs): + from litellm.proxy.proxy_server import llm_router, llm_model_list + import uuid + print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}") + + # get the prompt + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + # create an embedding for prompt + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + value = str(value) + assert isinstance(value, str) + + data = { + "points": [ + { + "id": str(uuid.uuid4()), + "vector": embedding, + "payload": { + "text": prompt, + "response": value, + } + }, + ] + } + + keys = await self.async_client.put( + url=f"{self.qdrant_url}/collections/{self.collection_name}/points", + headers=self.headers, + json=data + ) + return + + async def async_get_cache(self, key, **kwargs): + print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}") + from litellm.proxy.proxy_server import llm_router, llm_model_list + + # get the messages + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + data = { + "vector": embedding, + "params": { + "quantization": { + "ignore": False, + "rescore": True, + "oversampling": 3.0, + } + }, + "limit":1, + "with_payload": True + } + + search_response = await self.async_client.post( + url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search", + headers=self.headers, + json=data + ) + + results = search_response.json()["result"] + + if results == None: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + if isinstance(results, list): + if len(results) == 0: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + + similarity = results[0]["score"] + cached_prompt = results[0]["payload"]["text"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic 
cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + + # update kwargs["metadata"] with similarity, don't rewrite the original metadata + kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity + + if similarity >= self.similarity_threshold: + # cache hit ! + cached_value = results[0]["payload"]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! + return None + pass + + async def _collection_info(self): + return self.collection_info class S3Cache(BaseCache): def __init__( @@ -1676,7 +2080,7 @@ class Cache: def __init__( self, type: Optional[ - Literal["local", "redis", "redis-semantic", "s3", "disk"] + Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"] ] = "local", host: Optional[str] = None, port: Optional[str] = None, @@ -1725,17 +2129,27 @@ class Cache: redis_semantic_cache_embedding_model="text-embedding-ada-002", redis_flush_size=None, disk_cache_dir=None, + qdrant_url: Optional[str] = None, + qdrant_api_key: Optional[str] = None, + qdrant_collection_name: Optional[str] = None, + qdrant_quantization_config: Optional[str] = None, + qdrant_semantic_cache_embedding_model="text-embedding-ada-002", + qdrant_host_type: Optional[Literal["local","cloud"]] = "local", **kwargs, ): """ Initializes the cache based on the given type. Args: - type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local". + type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". host (str, optional): The host address for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis". - similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" + qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic". + qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. Required if qdrant_host_type is "cloud" and optional if qdrant_host_type is "local". + qdrant_host_type (str, optional): Can be either "local" or "cloud". Should be "local" when you are running a local qdrant cluster or "cloud" when you are using a qdrant cloud cluster. + qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic". + similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic". supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types. 
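+            qdrant_quantization_config (str, optional): The quantization scheme for the qdrant collection, one of "binary", "scalar" or "product". Defaults to binary quantization when type is "qdrant-semantic".
+            qdrant_semantic_cache_embedding_model (str, optional): The model used to embed prompts for semantic lookups; passed to litellm.embedding(), so any litellm.embedding() model is supported. Defaults to "text-embedding-ada-002".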
**kwargs: Additional keyword arguments for redis.Redis() cache @@ -1760,6 +2174,16 @@ class Cache: embedding_model=redis_semantic_cache_embedding_model, **kwargs, ) + elif type == "qdrant-semantic": + self.cache = QdrantSemanticCache( + qdrant_url= qdrant_url, + qdrant_api_key= qdrant_api_key, + collection_name= qdrant_collection_name, + similarity_threshold= similarity_threshold, + quantization_config= qdrant_quantization_config, + embedding_model= qdrant_semantic_cache_embedding_model, + host_type=qdrant_host_type + ) elif type == "local": self.cache = InMemoryCache() elif type == "s3": diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 1828a92d2e..de517a086c 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -129,6 +129,62 @@ class AsyncHTTPHandler: except Exception as e: raise e + async def put( + self, + url: str, + data: Optional[Union[dict, str]] = None, # type: ignore + json: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + stream: bool = False, + ): + try: + if timeout is None: + timeout = self.timeout + req = self.client.build_request( + "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + response = await self.client.send(req, stream=stream) + response.raise_for_status() + return response + except (httpx.RemoteProtocolError, httpx.ConnectError): + # Retry the request with a new session if there is a connection error + new_client = self.create_client(timeout=timeout, concurrent_limit=1) + try: + return await self.single_connection_post_request( + url=url, + client=new_client, + data=data, + json=json, + params=params, + headers=headers, + stream=stream, + ) + finally: + await new_client.aclose() + except httpx.TimeoutException as e: + headers = {} + if hasattr(e, "response") and e.response is not None: + for key, value in e.response.headers.items(): + headers["response_headers-{}".format(key)] = value + + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + headers=headers, + ) + except httpx.HTTPStatusError as e: + setattr(e, "status_code", e.response.status_code) + if stream is True: + setattr(e, "message", await e.response.aread()) + else: + setattr(e, "message", e.response.text) + raise e + except Exception as e: + raise e + async def delete( self, url: str, @@ -274,6 +330,38 @@ class HTTPHandler: except Exception as e: raise e + def put( + self, + url: str, + data: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str]] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + stream: bool = False, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ): + try: + + if timeout is not None: + req = self.client.build_request( + "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + else: + req = self.client.build_request( + "PUT", url, data=data, json=json, params=params, headers=headers # type: ignore + ) + response = self.client.send(req, stream=stream) + return response + except httpx.TimeoutException: + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + ) + except Exception as e: + raise e + + def __del__(self) -> None: try: 
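+            # best-effort cleanup: close the underlying httpx client when the handler
+            # is garbage collected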
self.close() @@ -335,4 +423,4 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler: _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client - return _new_client + return _new_client \ No newline at end of file diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 89b83bcd66..c201cd3a18 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -1732,3 +1732,108 @@ def test_caching_redis_simple(caplog, capsys): assert redis_async_caching_error is False assert redis_service_logging_error is False assert "async success_callback: reaches cache for logging" not in captured.out + +@pytest.mark.asyncio +async def test_qdrant_semantic_cache_acompletion(): + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding /reading from cache + + print("Testing Qdrant Semantic Caching with acompletion") + + litellm.cache = Cache( + type="qdrant-semantic", + qdrant_host_type="cloud", + qdrant_url=os.getenv("QDRANT_URL"), + qdrant_api_key=os.getenv("QDRANT_API_KEY"), + qdrant_collection_name='test_collection', + similarity_threshold=0.8, + qdrant_quantization_config="binary" + ) + + response1 = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": f"write a one sentence poem about: {random_number}", + } + ], + max_tokens=20, + ) + print(f"Response1: {response1}") + + random_number = random.randint(1, 100000) + + response2 = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": f"write a one sentence poem about: {random_number}", + } + ], + max_tokens=20, + ) + print(f"Response2: {response2}") + assert response1.id == response2.id + +@pytest.mark.asyncio +async def test_qdrant_semantic_cache_acompletion_stream(): + try: + random_word = generate_random_word() + messages = [ + { + "role": "user", + "content": f"write a joke about: {random_word}", + } + ] + litellm.cache = Cache( + type="qdrant-semantic", + qdrant_host_type="cloud", + qdrant_url=os.getenv("QDRANT_URL"), + qdrant_api_key=os.getenv("QDRANT_API_KEY"), + qdrant_collection_name='test_collection', + similarity_threshold=0.8, + qdrant_quantization_config="binary" + ) + print("Test Qdrant Semantic Caching with streaming + acompletion") + response_1_content = "" + response_2_content = "" + + response1 = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=1, + stream=True, + ) + async for chunk in response1: + response_1_id = chunk.id + response_1_content += chunk.choices[0].delta.content or "" + + time.sleep(2) + + response2 = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=1, + stream=True, + ) + async for chunk in response2: + response_2_id = chunk.id + response_2_content += chunk.choices[0].delta.content or "" + + print("\nResponse 1", response_1_content, "\nResponse 1 id", response_1_id) + print("\nResponse 2", response_2_content, "\nResponse 2 id", response_2_id) + assert ( + response_1_content == response_2_content + ), f"Response 1 != Response 2. 
Same params, Response 1{response_1_content} != Response 2{response_2_content}" + assert (response_1_id == response_2_id), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}" + litellm.cache = None + litellm.success_callback = [] + litellm._async_success_callback = [] + except Exception as e: + print(f"{str(e)}\n\n{traceback.format_exc()}") + raise e diff --git a/litellm/utils.py b/litellm/utils.py index ff6d2fd76b..a6d48dd311 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -121,7 +121,7 @@ import importlib.metadata from openai import OpenAIError as OriginalError from ._logging import verbose_logger -from .caching import RedisCache, RedisSemanticCache, S3Cache +from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache from .exceptions import ( APIConnectionError, APIError, @@ -1164,6 +1164,14 @@ def client(original_function): cached_result = await litellm.cache.async_get_cache( *args, **kwargs ) + elif isinstance(litellm.cache.cache, QdrantSemanticCache): + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + cached_result = await litellm.cache.async_get_cache( + *args, **kwargs + ) else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync] preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) kwargs["preset_cache_key"] = (