Merge pull request #1829 from BerriAI/litellm_add_semantic_cache

[Feat] Add Semantic Caching to litellm💰
This commit is contained in:
Ishaan Jaff 2024-02-06 13:18:59 -08:00 committed by GitHub
commit 8a8f538329
9 changed files with 569 additions and 16 deletions

View file

@@ -11,3 +11,4 @@ boto3
orjson
pydantic
google-cloud-aiplatform
redisvl==0.0.7 # semantic caching

View file

@@ -1,11 +1,11 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Caching - In-Memory, Redis, s3
# Caching - In-Memory, Redis, s3, Redis Semantic Cache
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)
## Initialize Cache - In Memory, Redis, s3 Bucket
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache
<Tabs>
@@ -18,7 +18,7 @@ pip install redis
```
For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/
### Quick Start
```python
import litellm
from litellm import completion
@@ -55,7 +55,7 @@ Set AWS environment variables
AWS_ACCESS_KEY_ID = "AKI*******"
AWS_SECRET_ACCESS_KEY = "WOl*****"
```
### Quick Start
```python
import litellm
from litellm import completion
@@ -80,6 +80,66 @@ response2 = completion(
</TabItem>
<TabItem value="redis-sem" label="redis-semantic cache">
Install redisvl
```shell
pip install redisvl==0.0.7
```
For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/
```python
import os
import random

import litellm
from litellm import completion
from litellm.caching import Cache
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
print("testing semantic caching")
litellm.cache = Cache(
type="redis-semantic",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
similarity_threshold=0.8, # similarity threshold for cache hits: 0 = no similarity required, 1 = exact match, 0.5 = 50% similarity
redis_semantic_cache_embedding_model="text-embedding-ada-002", # this model is passed to litellm.embedding(), any litellm.embedding() model is supported here
)
response1 = completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
],
max_tokens=20,
)
print(f"response1: {response1}")
random_number = random.randint(1, 100000)
response2 = completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
],
max_tokens=20,
)
print(f"response2: {response1}")
assert response1.id == response2.id
# response1 == response2, response 1 is cached
```
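The two prompts above differ only in the random number, so their embeddings are nearly identical. The second call scores above the 0.8 `similarity_threshold` and is served from the cache, which is why the response IDs match.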
</TabItem>
<TabItem value="in-mem" label="in memory cache">
### Quick Start

View file

@@ -7,9 +7,10 @@ Cache LLM Responses
LiteLLM supports:
- In Memory Cache
- Redis Cache
- Redis Semantic Cache
- s3 Bucket Cache
## Quick Start - Redis, s3 Cache
## Quick Start - Redis, s3 Cache, Semantic Cache
<Tabs>
<TabItem value="redis" label="redis cache">
@@ -84,6 +85,56 @@ litellm_settings:
$ litellm --config /path/to/config.yaml
```
</TabItem>
<TabItem value="redis-sem" label="redis semantic cache">
Caching can be enabled by adding the `cache` key to the `config.yaml`.
### Step 1: Add `cache` to the config.yaml
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
- model_name: azure-embedding-model
litellm_params:
model: azure/azure-embedding-model
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
litellm_settings:
set_verbose: True
cache: True # set cache responses to True, litellm defaults to using a redis cache
cache_params:
type: "redis-semantic"
similarity_threshold: 0.8 # similarity threshold for semantic cache
redis_semantic_cache_embedding_model: azure-embedding-model # set this to a model_name set in model_list
```
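Since `redis_semantic_cache_embedding_model` references a `model_name` from `model_list`, the proxy generates the embeddings used for cache lookups through that deployment (here, the Azure embedding model).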
### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or `REDIS_HOST` in your OS environment to enable caching.
```shell
REDIS_URL = "" # REDIS_URL='redis://username:password@hostname:port/database'
## OR ##
REDIS_HOST = "" # REDIS_HOST='redis-18841.c274.us-east-1-3.ec2.cloud.redislabs.com'
REDIS_PORT = "" # REDIS_PORT='18841'
REDIS_PASSWORD = "" # REDIS_PASSWORD='liteLlmIsAmazing'
```
**Additional kwargs**
You can pass any additional redis.Redis arg by storing the variable + value in your OS environment, like this:
```shell
REDIS_<redis-kwarg-name> = ""
```
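For example, to pass `ssl=True` to redis.Redis (an illustrative kwarg; any valid redis.Redis argument works the same way, though exact value parsing depends on how the env var is mapped to the kwarg):
```shell
REDIS_ssl = "True"
```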
### Step 3: Run proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
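To sanity-check the semantic cache, you can send two similar requests to the running proxy and compare the response IDs (a sketch, assuming the proxy is listening on its default local port; adjust the URL if yours differs):
```shell
curl http://0.0.0.0:8000/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "write a one sentence poem about: 17468"}]
  }'
```
Send the request again with a slightly different number; with semantic caching enabled, the second response should carry the same `id`.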
</TabItem>
</Tabs>

View file

@@ -83,7 +83,6 @@ class InMemoryCache(BaseCache):
self.cache_dict.clear()
self.ttl_dict.clear()
async def disconnect(self):
pass
@@ -217,7 +216,6 @@ class RedisCache(BaseCache):
def flush_cache(self):
self.redis_client.flushall()
async def disconnect(self):
pass
@@ -225,6 +223,314 @@ class RedisCache(BaseCache):
self.redis_client.delete(key)
class RedisSemanticCache(BaseCache):
def __init__(
self,
host=None,
port=None,
password=None,
redis_url=None,
similarity_threshold=None,
use_async=False,
embedding_model="text-embedding-ada-002",
**kwargs,
):
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
print_verbose(
"redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
)
if similarity_threshold is None:
raise Exception("similarity_threshold must be provided, passed None")
self.similarity_threshold = similarity_threshold
self.embedding_model = embedding_model
schema = {
"index": {
"name": "litellm_semantic_cache_index",
"prefix": "litellm",
"storage_type": "hash",
},
"fields": {
"text": [{"name": "response"}],
"text": [{"name": "prompt"}],
"vector": [
{
"name": "litellm_embedding",
"dims": 1536,
"distance_metric": "cosine",
"algorithm": "flat",
"datatype": "float32",
}
],
},
}
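# NOTE: "dims": 1536 matches text-embedding-ada-002, the default embedding model;
# if you pass a different embedding_model, its output dimension must match this schema.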
if redis_url is None:
# if no url passed, check if host, port and password are passed, if not raise an Exception
if host is None or port is None or password is None:
# try checking env for host, port and password
import os
host = os.getenv("REDIS_HOST")
port = os.getenv("REDIS_PORT")
password = os.getenv("REDIS_PASSWORD")
if host is None or port is None or password is None:
raise Exception("Redis host, port, and password must be provided")
redis_url = "redis://:" + password + "@" + host + ":" + port
print_verbose(f"redis semantic-cache redis_url: {redis_url}")
if not use_async:
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url)
try:
self.index.create(overwrite=False) # don't overwrite existing index
except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}")
else:
schema["index"]["name"] = "litellm_semantic_cache_index_async"
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url, use_async=True)
def _get_cache_logic(self, cached_response: Any):
"""
Common 'get_cache_logic' across sync + async redis client implementations
"""
if cached_response is None:
return cached_response
# check if cached_response is bytes
if isinstance(cached_response, bytes):
cached_response = cached_response.decode("utf-8")
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
def set_cache(self, key, value, **kwargs):
import numpy as np
print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
# make the embedding a numpy array, convert to bytes
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
value = str(value)
assert isinstance(value, str)
new_data = [
{"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
]
# load the new entry into the semantic cache index
self.index.load(new_data)
return
def get_cache(self, key, **kwargs):
print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
from redisvl.query import VectorQuery
import numpy as np
# build the prompt from the conversation messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# convert to embedding
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
query = VectorQuery(
vector=embedding,
vector_field_name="litellm_embedding",
return_fields=["response", "prompt", "vector_distance"],
num_results=1,
)
results = self.index.query(query)
if results is None:
return None
if isinstance(results, list):
if len(results) == 0:
return None
vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance)
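# Redis reports cosine *distance* (1 - cosine similarity), so invert it to recover similarity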
similarity = 1 - vector_distance
cached_prompt = results[0]["prompt"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
if similarity > self.similarity_threshold:
# cache hit !
cached_value = results[0]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
async def async_set_cache(self, key, value, **kwargs):
import numpy as np
from litellm.proxy.proxy_server import llm_router, llm_model_list
try:
await self.index.acreate(overwrite=False) # don't overwrite existing index
except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}")
print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
# make the embedding a numpy array, convert to bytes
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
value = str(value)
assert isinstance(value, str)
new_data = [
{"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
]
# load the new entry into the semantic cache index
await self.index.aload(new_data)
return
async def async_get_cache(self, key, **kwargs):
print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
from redisvl.query import VectorQuery
import numpy as np
from litellm.proxy.proxy_server import llm_router, llm_model_list
# build the prompt from the conversation messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
query = VectorQuery(
vector=embedding,
vector_field_name="litellm_embedding",
return_fields=["response", "prompt", "vector_distance"],
num_results=1,
)
results = await self.index.aquery(query)
if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
if len(results) == 0:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance)
similarity = 1 - vector_distance
cached_prompt = results[0]["prompt"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
if similarity > self.similarity_threshold:
# cache hit !
cached_value = results[0]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
class S3Cache(BaseCache):
def __init__(
@@ -429,10 +735,11 @@ class DualCache(BaseCache):
class Cache:
def __init__(
self,
type: Optional[Literal["local", "redis", "s3"]] = "local",
type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
similarity_threshold: Optional[float] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
@@ -447,16 +754,20 @@ class Cache:
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None,
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
**kwargs,
):
"""
Initializes the cache based on the given type.
Args:
type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic".
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache
@@ -468,6 +779,16 @@ class Cache:
"""
if type == "redis":
self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
elif type == "redis-semantic":
self.cache = RedisSemanticCache(
host,
port,
password,
similarity_threshold=similarity_threshold,
use_async=redis_semantic_cache_use_async,
embedding_model=redis_semantic_cache_embedding_model,
**kwargs,
)
elif type == "local":
self.cache = InMemoryCache()
elif type == "s3":
@@ -647,6 +968,7 @@ class Cache:
The cached result if it exists, otherwise None.
"""
try: # never block execution
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
@@ -656,7 +978,7 @@ class Cache:
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = self.cache.get_cache(cache_key)
cached_result = self.cache.get_cache(cache_key, messages=messages)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
@@ -671,6 +993,7 @@ class Cache:
Used for embedding calls in async wrapper
"""
try: # never block execution
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
@@ -680,7 +1003,9 @@ class Cache:
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(cache_key)
cached_result = await self.cache.async_get_cache(
cache_key, *args, **kwargs
)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)

View file

@@ -73,10 +73,14 @@ litellm_settings:
max_budget: 1.5000
models: ["azure-gpt-3.5"]
duration: None
cache: True # set cache responses to True
cache_params:
type: "redis-semantic"
similarity_threshold: 0.8
redis_semantic_cache_embedding_model: azure-embedding-model
upperbound_key_generate_params:
max_budget: 100
duration: "30d"
# cache: True
# setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]

View file

@@ -1168,7 +1168,7 @@ class ProxyConfig:
verbose_proxy_logger.debug(f"passed cache type={cache_type}")
if cache_type == "redis":
if cache_type == "redis" or cache_type == "redis-semantic":
cache_host = litellm.get_secret("REDIS_HOST", None)
cache_port = litellm.get_secret("REDIS_PORT", None)
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
@@ -1195,6 +1195,9 @@ class ProxyConfig:
f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
)
print() # noqa
if cache_type == "redis-semantic":
# by default this should always be async
cache_params.update({"redis_semantic_cache_use_async": True})
# users can pass os.environ/ variables on the proxy - we should read them from the env
for key, value in cache_params.items():

View file

@@ -987,3 +987,102 @@ def test_cache_context_managers():
# test_cache_context_managers()
@pytest.mark.skip(reason="beta test - new redis semantic cache")
def test_redis_semantic_cache_completion():
litellm.set_verbose = True
import logging
logging.basicConfig(level=logging.DEBUG)
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
print("testing semantic caching")
litellm.cache = Cache(
type="redis-semantic",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
similarity_threshold=0.8,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
)
response1 = completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
],
max_tokens=20,
)
print(f"response1: {response1}")
random_number = random.randint(1, 100000)
response2 = completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
],
max_tokens=20,
)
print(f"response2: {response1}")
assert response1.id == response2.id
# test_redis_semantic_cache_completion()
@pytest.mark.skip(reason="beta test - new redis semantic cache")
@pytest.mark.asyncio
async def test_redis_semantic_cache_acompletion():
litellm.set_verbose = True
import logging
logging.basicConfig(level=logging.DEBUG)
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
print("testing semantic caching")
litellm.cache = Cache(
type="redis-semantic",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
similarity_threshold=0.8,
redis_semantic_cache_use_async=True,
)
response1 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
],
max_tokens=5,
)
print(f"response1: {response1}")
random_number = random.randint(1, 100000)
response2 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
],
max_tokens=5,
)
print(f"response2: {response2}")
assert response1.id == response2.id

View file

@@ -55,7 +55,7 @@ from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache
from .caching import S3Cache, RedisSemanticCache
from .exceptions import (
AuthenticationError,
BadRequestError,
@@ -2533,6 +2533,14 @@ def client(original_function):
):
if len(cached_result) == 1 and cached_result[0] is None:
cached_result = None
elif isinstance(litellm.cache.cache, RedisSemanticCache):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs[
"preset_cache_key"
] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key
cached_result = await litellm.cache.async_get_cache(
*args, **kwargs
)
else:
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs[

View file

@@ -9,6 +9,8 @@ uvicorn==0.22.0 # server dep
gunicorn==21.2.0 # server dep
boto3==1.28.58 # aws bedrock/sagemaker calls
redis==4.6.0 # caching
redisvl==0.0.7 # semantic caching
numpy==1.24.3 # semantic caching
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
google-generativeai==0.3.2 # for vertex ai calls