Merge pull request #1829 from BerriAI/litellm_add_semantic_cache

[Feat] Add Semantic Caching to litellm 💰

Commit 8a8f538329 — 9 changed files with 569 additions and 16 deletions
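At a glance, the feature is opt-in: callers select the new `redis-semantic` cache type when constructing `litellm.Cache`. The sketch below condenses the docs added in this commit (it assumes `REDIS_HOST`/`REDIS_PORT`/`REDIS_PASSWORD` and an OpenAI API key are already set in the environment):

```python
import os
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis-semantic",                 # new cache type added by this PR
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
    similarity_threshold=0.8,              # cosine-similarity cutoff for a cache hit
    redis_semantic_cache_embedding_model="text-embedding-ada-002",
)

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "write a one sentence poem about spring"}],
)
# subsequent, semantically similar prompts should now be served from Redis
```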
@@ -11,3 +11,4 @@ boto3
 orjson
 pydantic
 google-cloud-aiplatform
+redisvl==0.0.7 # semantic caching
@@ -1,11 +1,11 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# Caching - In-Memory, Redis, s3
+# Caching - In-Memory, Redis, s3, Redis Semantic Cache
 
 [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)
 
-## Initialize Cache - In Memory, Redis, s3 Bucket
+## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache
 
 
 <Tabs>

@@ -18,7 +18,7 @@ pip install redis
 ```
 
 For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/
 ### Quick Start
 ```python
 import litellm
 from litellm import completion

@@ -55,7 +55,7 @@ Set AWS environment variables
 AWS_ACCESS_KEY_ID = "AKI*******"
 AWS_SECRET_ACCESS_KEY = "WOl*****"
 ```
 ### Quick Start
 ```python
 import litellm
 from litellm import completion

@@ -80,6 +80,66 @@ response2 = completion(
 </TabItem>
 
 
+<TabItem value="redis-sem" label="redis-semantic cache">
+
+Install redisvl
+```shell
+pip install redisvl==0.0.7
+```
+
+For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/
+
+```python
+import os
+import random
+
+import litellm
+from litellm import completion
+from litellm.caching import Cache
+
+random_number = random.randint(
+    1, 100000
+)  # add a random number to ensure it's always adding / reading from cache
+
+print("testing semantic caching")
+litellm.cache = Cache(
+    type="redis-semantic",
+    host=os.environ["REDIS_HOST"],
+    port=os.environ["REDIS_PORT"],
+    password=os.environ["REDIS_PASSWORD"],
+    similarity_threshold=0.8,  # similarity threshold for cache hits: 0 == no similarity required, 1 == exact matches only, 0.5 == 50% similarity
+    redis_semantic_cache_embedding_model="text-embedding-ada-002",  # this model is passed to litellm.embedding(); any litellm.embedding() model is supported here
+)
+response1 = completion(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {random_number}",
+        }
+    ],
+    max_tokens=20,
+)
+print(f"response1: {response1}")
+
+random_number = random.randint(1, 100000)
+
+response2 = completion(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {random_number}",
+        }
+    ],
+    max_tokens=20,
+)
+print(f"response2: {response2}")
+assert response1.id == response2.id
+# response1 == response2; response1 was served from the cache
+```
+
+</TabItem>
+
+
 <TabItem value="in-mem" label="in memory cache">
 
 ### Quick Start
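The `similarity_threshold` comment above is worth unpacking: redisvl returns a cosine `vector_distance` for the closest stored prompt, and the cache (see the `RedisSemanticCache.get_cache` change further down in this diff) compares `1 - distance` against the threshold. A standalone sketch of that decision, for illustration only:

```python
def is_semantic_hit(vector_distance: float, similarity_threshold: float = 0.8) -> bool:
    """Mirror of the hit/miss check in RedisSemanticCache.get_cache:
    cosine distance 0.0 -> similarity 1.0 (exact match); distance 1.0 -> similarity 0.0."""
    similarity = 1 - vector_distance
    return similarity > similarity_threshold

print(is_semantic_hit(0.05))  # True  -> near-duplicate prompt, serve the cached response
print(is_semantic_hit(0.40))  # False -> prompt too different, call the model
```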
@@ -7,9 +7,10 @@ Cache LLM Responses
 LiteLLM supports:
 - In Memory Cache
 - Redis Cache
+- Redis Semantic Cache
 - s3 Bucket Cache
 
-## Quick Start - Redis, s3 Cache
+## Quick Start - Redis, s3 Cache, Semantic Cache
 <Tabs>
 
 <TabItem value="redis" label="redis cache">

@@ -84,6 +85,56 @@ litellm_settings:
 $ litellm --config /path/to/config.yaml
 ```
 </TabItem>
 
 
+<TabItem value="redis-sem" label="redis semantic cache">
+
+Caching can be enabled by adding the `cache` key to `config.yaml`.
+
+### Step 1: Add `cache` to the config.yaml
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: gpt-3.5-turbo
+  - model_name: azure-embedding-model
+    litellm_params:
+      model: azure/azure-embedding-model
+      api_base: os.environ/AZURE_API_BASE
+      api_key: os.environ/AZURE_API_KEY
+      api_version: "2023-07-01-preview"
+
+litellm_settings:
+  set_verbose: True
+  cache: True # set cache responses to True; litellm defaults to using a redis cache
+  cache_params:
+    type: "redis-semantic"
+    similarity_threshold: 0.8 # similarity threshold for the semantic cache
+    redis_semantic_cache_embedding_model: azure-embedding-model # set this to a model_name defined in model_list
+```
+
+### Step 2: Add Redis Credentials to .env
+Set either `REDIS_URL` or `REDIS_HOST` in your OS environment to enable caching.
+
+```shell
+REDIS_URL = ""      # REDIS_URL='redis://username:password@hostname:port/database'
+## OR ##
+REDIS_HOST = ""     # REDIS_HOST='redis-18841.c274.us-east-1-3.ec2.cloud.redislabs.com'
+REDIS_PORT = ""     # REDIS_PORT='18841'
+REDIS_PASSWORD = "" # REDIS_PASSWORD='liteLlmIsAmazing'
+```
+
+**Additional kwargs**
+You can pass any additional redis.Redis argument by storing the variable and its value in your OS environment, like this:
+```shell
+REDIS_<redis-kwarg-name> = ""
+```
+
+### Step 3: Run proxy with config
+```shell
+$ litellm --config /path/to/config.yaml
+```
+</TabItem>
 </Tabs>
 
 
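To exercise the semantic cache end to end after Step 3, send two similar requests to the proxy's OpenAI-compatible `/chat/completions` route. The sketch below is illustrative only: it assumes the proxy is listening on `http://0.0.0.0:8000` and that the placeholder key `sk-1234` is accepted; adjust both to your deployment.

```python
import requests

url = "http://0.0.0.0:8000/chat/completions"   # assumed local proxy address
headers = {"Authorization": "Bearer sk-1234"}  # placeholder key
payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "what is litellm?"}],
}

first = requests.post(url, json=payload, headers=headers).json()
second = requests.post(url, json=payload, headers=headers).json()  # should be served from the semantic cache
print(first["id"], second["id"])  # matching ids indicate a cache hit, as in the SDK example above
```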
@@ -83,7 +83,6 @@ class InMemoryCache(BaseCache):
         self.cache_dict.clear()
         self.ttl_dict.clear()
 
-
     async def disconnect(self):
         pass
 

@@ -217,7 +216,6 @@ class RedisCache(BaseCache):
     def flush_cache(self):
         self.redis_client.flushall()
 
-
    async def disconnect(self):
        pass
 
@@ -225,6 +223,314 @@ class RedisCache(BaseCache):
         self.redis_client.delete(key)
 
 
+class RedisSemanticCache(BaseCache):
+    def __init__(
+        self,
+        host=None,
+        port=None,
+        password=None,
+        redis_url=None,
+        similarity_threshold=None,
+        use_async=False,
+        embedding_model="text-embedding-ada-002",
+        **kwargs,
+    ):
+        from redisvl.index import SearchIndex
+        from redisvl.query import VectorQuery
+
+        print_verbose(
+            "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
+        )
+        if similarity_threshold is None:
+            raise Exception("similarity_threshold must be provided, passed None")
+        self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
+        schema = {
+            "index": {
+                "name": "litellm_semantic_cache_index",
+                "prefix": "litellm",
+                "storage_type": "hash",
+            },
+            "fields": {
+                # index the cached response and the prompt as text fields under a single
+                # "text" key (two separate "text" keys would silently overwrite each other)
+                "text": [{"name": "response"}, {"name": "prompt"}],
+                "vector": [
+                    {
+                        "name": "litellm_embedding",
+                        "dims": 1536,
+                        "distance_metric": "cosine",
+                        "algorithm": "flat",
+                        "datatype": "float32",
+                    }
+                ],
+            },
+        }
+        if redis_url is None:
+            # if no url is passed, check whether host, port and password are passed; if not, raise an Exception
+            if host is None or port is None or password is None:
+                # try reading host, port and password from the environment
+                import os
+
+                host = os.getenv("REDIS_HOST")
+                port = os.getenv("REDIS_PORT")
+                password = os.getenv("REDIS_PASSWORD")
+                if host is None or port is None or password is None:
+                    raise Exception("Redis host, port, and password must be provided")
+
+            redis_url = "redis://:" + password + "@" + host + ":" + port
+        print_verbose(f"redis semantic-cache redis_url: {redis_url}")
+        if use_async == False:
+            self.index = SearchIndex.from_dict(schema)
+            self.index.connect(redis_url=redis_url)
+            try:
+                self.index.create(overwrite=False)  # don't overwrite an existing index
+            except Exception as e:
+                print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+        elif use_async == True:
+            schema["index"]["name"] = "litellm_semantic_cache_index_async"
+            self.index = SearchIndex.from_dict(schema)
+            self.index.connect(redis_url=redis_url, use_async=True)
+
+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+
+        # check if cached_response is bytes
+        if isinstance(cached_response, bytes):
+            cached_response = cached_response.decode("utf-8")
+
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # convert string to dictionary
+        except:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
+    def set_cache(self, key, value, **kwargs):
+        import numpy as np
+
+        print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # create an embedding for the prompt
+        embedding_response = litellm.embedding(
+            model=self.embedding_model,
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        # make the embedding a numpy array, convert to bytes
+        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+        value = str(value)
+        assert isinstance(value, str)
+
+        new_data = [
+            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+        ]
+
+        # add the new entry to the index
+        keys = self.index.load(new_data)
+
+        return
+
+    def get_cache(self, key, **kwargs):
+        print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
+        from redisvl.query import VectorQuery
+        import numpy as np
+
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # convert to embedding
+        embedding_response = litellm.embedding(
+            model=self.embedding_model,
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        query = VectorQuery(
+            vector=embedding,
+            vector_field_name="litellm_embedding",
+            return_fields=["response", "prompt", "vector_distance"],
+            num_results=1,
+        )
+
+        results = self.index.query(query)
+        if results == None:
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                return None
+
+        vector_distance = results[0]["vector_distance"]
+        vector_distance = float(vector_distance)
+        similarity = 1 - vector_distance
+        cached_prompt = results[0]["prompt"]
+
+        # check similarity; if greater than self.similarity_threshold, return the cached result
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+        if similarity > self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
+
+    async def async_set_cache(self, key, value, **kwargs):
+        import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
+
+        try:
+            await self.index.acreate(overwrite=False)  # don't overwrite an existing index
+        except Exception as e:
+            print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+        print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+        # create an embedding for the prompt
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        # make the embedding a numpy array, convert to bytes
+        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+        value = str(value)
+        assert isinstance(value, str)
+
+        new_data = [
+            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+        ]
+
+        # add the new entry to the index
+        keys = await self.index.aload(new_data)
+        return
+
+    async def async_get_cache(self, key, **kwargs):
+        print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
+        from redisvl.query import VectorQuery
+        import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
+
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        query = VectorQuery(
+            vector=embedding,
+            vector_field_name="litellm_embedding",
+            return_fields=["response", "prompt", "vector_distance"],
+        )
+        results = await self.index.aquery(query)
+        if results == None:
+            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+                return None
+
+        vector_distance = results[0]["vector_distance"]
+        vector_distance = float(vector_distance)
+        similarity = 1 - vector_distance
+        cached_prompt = results[0]["prompt"]
+
+        # check similarity; if greater than self.similarity_threshold, return the cached result
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+
+        # update kwargs["metadata"] with the similarity; don't rewrite the original metadata
+        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
+
+        if similarity > self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
+
+
 class S3Cache(BaseCache):
     def __init__(
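`set_cache` serializes the embedding to raw float32 bytes before loading it into the hash index, because that is the layout the vector field above declares (`datatype: float32`). The round trip in isolation, as an illustrative sketch that is not part of the commit:

```python
import numpy as np

embedding = [0.12, -0.07, 0.93]  # toy 3-dim embedding; the real index is configured for 1536 dims

# serialize the way set_cache does
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()

# deserialize (what redisvl effectively does when computing cosine distance)
recovered = np.frombuffer(embedding_bytes, dtype=np.float32)
assert np.allclose(recovered, np.array(embedding, dtype=np.float32))
print(len(embedding_bytes))  # 3 floats * 4 bytes = 12
```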
@@ -429,10 +735,11 @@ class DualCache(BaseCache):
 class Cache:
     def __init__(
         self,
-        type: Optional[Literal["local", "redis", "s3"]] = "local",
+        type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
         host: Optional[str] = None,
         port: Optional[str] = None,
         password: Optional[str] = None,
+        similarity_threshold: Optional[float] = None,
         supported_call_types: Optional[
             List[Literal["completion", "acompletion", "embedding", "aembedding"]]
         ] = ["completion", "acompletion", "embedding", "aembedding"],

@@ -447,16 +754,20 @@ class Cache:
         s3_aws_secret_access_key: Optional[str] = None,
         s3_aws_session_token: Optional[str] = None,
         s3_config: Optional[Any] = None,
+        redis_semantic_cache_use_async=False,
+        redis_semantic_cache_embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         """
         Initializes the cache based on the given type.
 
         Args:
-            type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
+            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
             host (str, optional): The host address for the Redis cache. Required if type is "redis".
             port (int, optional): The port number for the Redis cache. Required if type is "redis".
             password (str, optional): The password for the Redis cache. Required if type is "redis".
+            similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic".
+
             supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
             **kwargs: Additional keyword arguments for redis.Redis() cache

@@ -468,6 +779,16 @@ class Cache:
         """
         if type == "redis":
             self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
+        elif type == "redis-semantic":
+            self.cache = RedisSemanticCache(
+                host,
+                port,
+                password,
+                similarity_threshold=similarity_threshold,
+                use_async=redis_semantic_cache_use_async,
+                embedding_model=redis_semantic_cache_embedding_model,
+                **kwargs,
+            )
         elif type == "local":
             self.cache = InMemoryCache()
         elif type == "s3":

@@ -647,6 +968,7 @@ class Cache:
             The cached result if it exists, otherwise None.
         """
         try:  # never block execution
+            messages = kwargs.get("messages", [])
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
             else:

@@ -656,7 +978,7 @@ class Cache:
                 max_age = cache_control_args.get(
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
-            cached_result = self.cache.get_cache(cache_key)
+            cached_result = self.cache.get_cache(cache_key, messages=messages)
             return self._get_cache_logic(
                 cached_result=cached_result, max_age=max_age
             )

@@ -671,6 +993,7 @@ class Cache:
         Used for embedding calls in async wrapper
         """
         try:  # never block execution
+            messages = kwargs.get("messages", [])
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
             else:

@@ -680,7 +1003,9 @@ class Cache:
                 max_age = cache_control_args.get(
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
-            cached_result = await self.cache.async_get_cache(cache_key)
+            cached_result = await self.cache.async_get_cache(
+                cache_key, *args, **kwargs
+            )
             return self._get_cache_logic(
                 cached_result=cached_result, max_age=max_age
             )
@@ -73,10 +73,14 @@ litellm_settings:
     max_budget: 1.5000
     models: ["azure-gpt-3.5"]
     duration: None
+  cache: True # set cache responses to True
+  cache_params:
+    type: "redis-semantic"
+    similarity_threshold: 0.8
+    redis_semantic_cache_embedding_model: azure-embedding-model
   upperbound_key_generate_params:
     max_budget: 100
     duration: "30d"
-  # cache: True
 
   # setting callback class
   # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
@@ -1168,7 +1168,7 @@ class ProxyConfig:
 
                     verbose_proxy_logger.debug(f"passed cache type={cache_type}")
 
-                    if cache_type == "redis":
+                    if cache_type == "redis" or cache_type == "redis-semantic":
                         cache_host = litellm.get_secret("REDIS_HOST", None)
                         cache_port = litellm.get_secret("REDIS_PORT", None)
                         cache_password = litellm.get_secret("REDIS_PASSWORD", None)

@@ -1195,6 +1195,9 @@ class ProxyConfig:
                             f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
                         )
                         print()  # noqa
+                    if cache_type == "redis-semantic":
+                        # by default this should always be async
+                        cache_params.update({"redis_semantic_cache_use_async": True})
+
                     # users can pass os.environ/ variables on the proxy - we should read them from the env
                     for key, value in cache_params.items():
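The comment before the `cache_params` loop refers to litellm's `os.environ/NAME` convention, the same one used for `api_key: os.environ/AZURE_API_KEY` in the example config earlier in this diff. A self-contained sketch of that kind of resolution; the helper below is illustrative only, not litellm's actual implementation:

```python
import os

def resolve_os_environ_refs(cache_params: dict) -> dict:
    """Illustrative helper: replace values of the form 'os.environ/NAME' with os.environ['NAME']."""
    resolved = {}
    for key, value in cache_params.items():
        if isinstance(value, str) and value.startswith("os.environ/"):
            env_var = value.split("/", 1)[1]
            resolved[key] = os.environ.get(env_var)
        else:
            resolved[key] = value
    return resolved

# e.g. {"type": "redis-semantic", "host": "os.environ/REDIS_HOST"}
#   -> {"type": "redis-semantic", "host": "<value of REDIS_HOST>"}
```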
@@ -987,3 +987,102 @@ def test_cache_context_managers():
 
 
 # test_cache_context_managers()
+
+
+@pytest.mark.skip(reason="beta test - new redis semantic cache")
+def test_redis_semantic_cache_completion():
+    litellm.set_verbose = True
+    import logging
+
+    logging.basicConfig(level=logging.DEBUG)
+
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+
+    print("testing semantic caching")
+    litellm.cache = Cache(
+        type="redis-semantic",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+        similarity_threshold=0.8,
+        redis_semantic_cache_embedding_model="text-embedding-ada-002",
+    )
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "user",
+                "content": f"write a one sentence poem about: {random_number}",
+            }
+        ],
+        max_tokens=20,
+    )
+    print(f"response1: {response1}")
+
+    random_number = random.randint(1, 100000)
+
+    response2 = completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "user",
+                "content": f"write a one sentence poem about: {random_number}",
+            }
+        ],
+        max_tokens=20,
+    )
+    print(f"response2: {response2}")
+    assert response1.id == response2.id
+
+
+# test_redis_semantic_cache_completion()
+
+
+@pytest.mark.skip(reason="beta test - new redis semantic cache")
+@pytest.mark.asyncio
+async def test_redis_semantic_cache_acompletion():
+    litellm.set_verbose = True
+    import logging
+
+    logging.basicConfig(level=logging.DEBUG)
+
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+
+    print("testing semantic caching")
+    litellm.cache = Cache(
+        type="redis-semantic",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+        similarity_threshold=0.8,
+        redis_semantic_cache_use_async=True,
+    )
+    response1 = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "user",
+                "content": f"write a one sentence poem about: {random_number}",
+            }
+        ],
+        max_tokens=5,
+    )
+    print(f"response1: {response1}")
+
+    random_number = random.randint(1, 100000)
+    response2 = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "user",
+                "content": f"write a one sentence poem about: {random_number}",
+            }
+        ],
+        max_tokens=5,
+    )
+    print(f"response2: {response2}")
+    assert response1.id == response2.id
@@ -55,7 +55,7 @@ from .integrations.litedebugger import LiteDebugger
 from .proxy._types import KeyManagementSystem
 from openai import OpenAIError as OriginalError
 from openai._models import BaseModel as OpenAIObject
-from .caching import S3Cache
+from .caching import S3Cache, RedisSemanticCache
 from .exceptions import (
     AuthenticationError,
     BadRequestError,

@@ -2533,6 +2533,14 @@ def client(original_function):
                     ):
                         if len(cached_result) == 1 and cached_result[0] is None:
                             cached_result = None
+                    elif isinstance(litellm.cache.cache, RedisSemanticCache):
+                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
+                        kwargs[
+                            "preset_cache_key"
+                        ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+                        cached_result = await litellm.cache.async_get_cache(
+                            *args, **kwargs
+                        )
                     else:
                         preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                         kwargs[
@@ -9,6 +9,8 @@ uvicorn==0.22.0 # server dep
 gunicorn==21.2.0 # server dep
 boto3==1.28.58 # aws bedrock/sagemaker calls
 redis==4.6.0 # caching
+redisvl==0.0.7 # semantic caching
+numpy==1.24.3 # semantic caching
 prisma==0.11.0 # for db
 mangum==0.17.0 # for aws lambda functions
 google-generativeai==0.3.2 # for vertex ai calls