Merge pull request #1829 from BerriAI/litellm_add_semantic_cache
[Feat] Add Semantic Caching to litellm💰
Commit 8a8f538329
9 changed files with 569 additions and 16 deletions
@@ -10,4 +10,5 @@ anthropic
boto3
orjson
pydantic
google-cloud-aiplatform
redisvl==0.0.7 # semantic caching

@@ -1,11 +1,11 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Caching - In-Memory, Redis, s3
# Caching - In-Memory, Redis, s3, Redis Semantic Cache

[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)

## Initialize Cache - In Memory, Redis, s3 Bucket
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache

<Tabs>

@@ -18,7 +18,7 @@ pip install redis
```

For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/
### Quick Start

```python
import litellm
from litellm import completion

@@ -55,7 +55,7 @@ Set AWS environment variables
AWS_ACCESS_KEY_ID = "AKI*******"
AWS_SECRET_ACCESS_KEY = "WOl*****"
```
### Quick Start

```python
import litellm
from litellm import completion

@@ -80,6 +80,66 @@ response2 = completion(
</TabItem>


<TabItem value="redis-sem" label="redis-semantic cache">

Install redisvl
```shell
pip install redisvl==0.0.7
```

For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/

```python
import os
import random

import litellm
from litellm import completion
from litellm.caching import Cache

random_number = random.randint(
    1, 100000
)  # add a random number to ensure it's always adding / reading from cache

print("testing semantic caching")
litellm.cache = Cache(
    type="redis-semantic",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
    similarity_threshold=0.8,  # similarity threshold for cache hits, 0 == no similarity, 1 == exact matches, 0.5 == 50% similarity
    redis_semantic_cache_embedding_model="text-embedding-ada-002",  # this model is passed to litellm.embedding(), any litellm.embedding() model is supported here
)
response1 = completion(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": f"write a one sentence poem about: {random_number}",
        }
    ],
    max_tokens=20,
)
print(f"response1: {response1}")

random_number = random.randint(1, 100000)

response2 = completion(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": f"write a one sentence poem about: {random_number}",
        }
    ],
    max_tokens=20,
)
print(f"response2: {response2}")
assert response1.id == response2.id
# response1 == response2, response 1 is cached
```
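
How `similarity_threshold` is applied: the cache embeds the new prompt, retrieves the closest stored prompt by cosine distance, and treats it as a hit when `1 - distance` exceeds the threshold. A minimal standalone sketch of that comparison (the toy vectors here are stand-ins for real `litellm.embedding()` output):

```python
import numpy as np

def is_semantic_hit(new_emb, cached_emb, similarity_threshold=0.8):
    """Mirror of the cache's check: similarity = 1 - cosine distance."""
    a = np.asarray(new_emb, dtype=np.float32)
    b = np.asarray(cached_emb, dtype=np.float32)
    cosine_distance = 1 - float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    similarity = 1 - cosine_distance
    return similarity > similarity_threshold

# toy vectors standing in for prompt embeddings
print(is_semantic_hit([0.1, 0.9, 0.0], [0.12, 0.88, 0.01]))  # True - near-duplicate prompts
print(is_semantic_hit([0.1, 0.9, 0.0], [0.9, 0.05, 0.1]))    # False - unrelated prompts
```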

</TabItem>


<TabItem value="in-mem" label="in memory cache">

### Quick Start

@@ -7,9 +7,10 @@ Cache LLM Responses
LiteLLM supports:
- In Memory Cache
- Redis Cache
- Redis Semantic Cache
- s3 Bucket Cache

## Quick Start - Redis, s3 Cache
## Quick Start - Redis, s3 Cache, Semantic Cache
<Tabs>

<TabItem value="redis" label="redis cache">

@@ -84,6 +85,56 @@ litellm_settings:
$ litellm --config /path/to/config.yaml
```
</TabItem>


<TabItem value="redis-sem" label="redis semantic cache">

Caching can be enabled by adding the `cache` key to the `config.yaml`.

### Step 1: Add `cache` to the config.yaml
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
  - model_name: azure-embedding-model
    litellm_params:
      model: azure/azure-embedding-model
      api_base: os.environ/AZURE_API_BASE
      api_key: os.environ/AZURE_API_KEY
      api_version: "2023-07-01-preview"

litellm_settings:
  set_verbose: True
  cache: True # set cache responses to True, litellm defaults to using a redis cache
  cache_params:
    type: "redis-semantic"
    similarity_threshold: 0.8 # similarity threshold for semantic cache
    redis_semantic_cache_embedding_model: azure-embedding-model # set this to a model_name set in model_list
```

### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or `REDIS_HOST` in your os environment to enable caching.

```shell
REDIS_URL = ""       # REDIS_URL='redis://username:password@hostname:port/database'
## OR ##
REDIS_HOST = ""      # REDIS_HOST='redis-18841.c274.us-east-1-3.ec2.cloud.redislabs.com'
REDIS_PORT = ""      # REDIS_PORT='18841'
REDIS_PASSWORD = ""  # REDIS_PASSWORD='liteLlmIsAmazing'
```

**Additional kwargs**
You can pass in any additional redis.Redis arg by storing the variable + value in your os environment, like this:
```shell
REDIS_<redis-kwarg-name> = ""
```

### Step 3: Run proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
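
Once the proxy is running, you can sanity-check the semantic cache from any OpenAI-compatible client by sending two semantically identical requests and comparing response ids. A minimal sketch; the `base_url` (local proxy on port 8000) and the dummy `api_key` are assumptions about your deployment:

```python
import openai

# assumption: proxy running locally on the default port; adjust base_url to your setup
client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:8000")

def ask(prompt: str):
    return client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=20,
    )

r1 = ask("write a one sentence poem about: 7890")
r2 = ask("write a one sentence poem about: 7890")  # semantically identical prompt
print(r1.id, r2.id)  # ids should match when the second call is served from the cache
```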

</TabItem>
</Tabs>

@@ -83,7 +83,6 @@ class InMemoryCache(BaseCache):
        self.cache_dict.clear()
        self.ttl_dict.clear()

    async def disconnect(self):
        pass

@@ -217,7 +216,6 @@ class RedisCache(BaseCache):
    def flush_cache(self):
        self.redis_client.flushall()

    async def disconnect(self):
        pass

@@ -225,6 +223,314 @@ class RedisCache(BaseCache):
        self.redis_client.delete(key)


class RedisSemanticCache(BaseCache):
    def __init__(
        self,
        host=None,
        port=None,
        password=None,
        redis_url=None,
        similarity_threshold=None,
        use_async=False,
        embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
        from redisvl.index import SearchIndex
        from redisvl.query import VectorQuery

        print_verbose(
            "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
        )
        if similarity_threshold is None:
            raise Exception("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        schema = {
            "index": {
                "name": "litellm_semantic_cache_index",
                "prefix": "litellm",
                "storage_type": "hash",
            },
            "fields": {
                "text": [{"name": "response"}, {"name": "prompt"}],
                "vector": [
                    {
                        "name": "litellm_embedding",
                        "dims": 1536,
                        "distance_metric": "cosine",
                        "algorithm": "flat",
                        "datatype": "float32",
                    }
                ],
            },
        }
        if redis_url is None:
            # if no url passed, check if host, port and password are passed, if not raise an Exception
            if host is None or port is None or password is None:
                # try checking env for host, port and password
                import os

                host = os.getenv("REDIS_HOST")
                port = os.getenv("REDIS_PORT")
                password = os.getenv("REDIS_PASSWORD")
                if host is None or port is None or password is None:
                    raise Exception("Redis host, port, and password must be provided")

            redis_url = "redis://:" + password + "@" + host + ":" + port
        print_verbose(f"redis semantic-cache redis_url: {redis_url}")
        if use_async == False:
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url)
            try:
                self.index.create(overwrite=False)  # don't overwrite existing index
            except Exception as e:
                print_verbose(f"Got exception creating semantic cache index: {str(e)}")
        elif use_async == True:
            schema["index"]["name"] = "litellm_semantic_cache_index_async"
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url, use_async=True)

    #
    def _get_cache_logic(self, cached_response: Any):
        """
        Common 'get_cache_logic' across sync + async redis client implementations
        """
        if cached_response is None:
            return cached_response

        # check if cached_response is bytes
        if isinstance(cached_response, bytes):
            cached_response = cached_response.decode("utf-8")

        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except:
            cached_response = ast.literal_eval(cached_response)
        return cached_response

    def set_cache(self, key, value, **kwargs):
        import numpy as np

        print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")

        # get the prompt
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # create an embedding for prompt
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        # make the embedding a numpy array, convert to bytes
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        value = str(value)
        assert isinstance(value, str)

        new_data = [
            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
        ]

        # Add more data
        keys = self.index.load(new_data)

        return

    def get_cache(self, key, **kwargs):
        print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
        from redisvl.query import VectorQuery
        import numpy as np

        # query

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # convert to embedding
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        query = VectorQuery(
            vector=embedding,
            vector_field_name="litellm_embedding",
            return_fields=["response", "prompt", "vector_distance"],
            num_results=1,
        )

        results = self.index.query(query)
        if results == None:
            return None
        if isinstance(results, list):
            if len(results) == 0:
                return None

        vector_distance = results[0]["vector_distance"]
        vector_distance = float(vector_distance)
        similarity = 1 - vector_distance
        cached_prompt = results[0]["prompt"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        if similarity > self.similarity_threshold:
            # cache hit !
            cached_value = results[0]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss !
            return None

        pass

    async def async_set_cache(self, key, value, **kwargs):
        import numpy as np
        from litellm.proxy.proxy_server import llm_router, llm_model_list

        try:
            await self.index.acreate(overwrite=False)  # don't overwrite existing index
        except Exception as e:
            print_verbose(f"Got exception creating semantic cache index: {str(e)}")
        print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")

        # get the prompt
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]
        # create an embedding for prompt
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        # make the embedding a numpy array, convert to bytes
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        value = str(value)
        assert isinstance(value, str)

        new_data = [
            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
        ]

        # Add more data
        keys = await self.index.aload(new_data)
        return

    async def async_get_cache(self, key, **kwargs):
        print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
        from redisvl.query import VectorQuery
        import numpy as np
        from litellm.proxy.proxy_server import llm_router, llm_model_list

        # query

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        query = VectorQuery(
            vector=embedding,
            vector_field_name="litellm_embedding",
            return_fields=["response", "prompt", "vector_distance"],
        )
        results = await self.index.aquery(query)
        if results == None:
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None
        if isinstance(results, list):
            if len(results) == 0:
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None

        vector_distance = results[0]["vector_distance"]
        vector_distance = float(vector_distance)
        similarity = 1 - vector_distance
        cached_prompt = results[0]["prompt"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )

        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

        if similarity > self.similarity_threshold:
            # cache hit !
            cached_value = results[0]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss !
            return None
        pass


class S3Cache(BaseCache):
    def __init__(
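
The storage/lookup pattern used by `RedisSemanticCache` above can be exercised directly with redisvl. A condensed sketch, assuming redisvl==0.0.7 and a local Redis instance with the search module (e.g. Redis Stack); the index name, field names, placeholder redis_url, and the random vector are illustrative stand-ins for the real values and the embedding produced by `litellm.embedding()`:

```python
import numpy as np
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery

schema = {
    "index": {"name": "demo_semantic_index", "prefix": "demo", "storage_type": "hash"},
    "fields": {
        "text": [{"name": "response"}, {"name": "prompt"}],
        "vector": [
            {
                "name": "embedding",
                "dims": 1536,
                "distance_metric": "cosine",
                "algorithm": "flat",
                "datatype": "float32",
            }
        ],
    },
}

index = SearchIndex.from_dict(schema)
index.connect(redis_url="redis://:password@localhost:6379")  # placeholder URL
try:
    index.create(overwrite=False)  # don't overwrite an existing index
except Exception as e:
    print(f"index may already exist: {e}")

vec = np.random.rand(1536).astype(np.float32)  # stand-in for a prompt embedding
index.load([{"response": "cached answer", "prompt": "cached prompt", "embedding": vec.tobytes()}])

query = VectorQuery(
    vector=vec.tolist(),
    vector_field_name="embedding",
    return_fields=["response", "prompt", "vector_distance"],
    num_results=1,
)
results = index.query(query)
print(1 - float(results[0]["vector_distance"]))  # similarity of the closest stored prompt
```

In the class above, the same query is issued with the request's embedding and the returned `vector_distance` is converted to a similarity before being compared against `similarity_threshold`.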

@@ -429,10 +735,11 @@ class DualCache(BaseCache):
class Cache:
    def __init__(
        self,
        type: Optional[Literal["local", "redis", "s3"]] = "local",
        type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        similarity_threshold: Optional[float] = None,
        supported_call_types: Optional[
            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
        ] = ["completion", "acompletion", "embedding", "aembedding"],

@@ -447,16 +754,20 @@ class Cache:
        s3_aws_secret_access_key: Optional[str] = None,
        s3_aws_session_token: Optional[str] = None,
        s3_config: Optional[Any] = None,
        redis_semantic_cache_use_async=False,
        redis_semantic_cache_embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
        """
        Initializes the cache based on the given type.

        Args:
            type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
            host (str, optional): The host address for the Redis cache. Required if type is "redis".
            port (int, optional): The port number for the Redis cache. Required if type is "redis".
            password (str, optional): The password for the Redis cache. Required if type is "redis".
            similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic"

            supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
            **kwargs: Additional keyword arguments for redis.Redis() cache

@@ -468,6 +779,16 @@ class Cache:
        """
        if type == "redis":
            self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
        elif type == "redis-semantic":
            self.cache = RedisSemanticCache(
                host,
                port,
                password,
                similarity_threshold=similarity_threshold,
                use_async=redis_semantic_cache_use_async,
                embedding_model=redis_semantic_cache_embedding_model,
                **kwargs,
            )
        elif type == "local":
            self.cache = InMemoryCache()
        elif type == "s3":

@@ -647,6 +968,7 @@ class Cache:
            The cached result if it exists, otherwise None.
        """
        try:  # never block execution
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:

@@ -656,7 +978,7 @@ class Cache:
            max_age = cache_control_args.get(
                "s-max-age", cache_control_args.get("s-maxage", float("inf"))
            )
            cached_result = self.cache.get_cache(cache_key)
            cached_result = self.cache.get_cache(cache_key, messages=messages)
            return self._get_cache_logic(
                cached_result=cached_result, max_age=max_age
            )

@@ -671,6 +993,7 @@ class Cache:
        Used for embedding calls in async wrapper
        """
        try:  # never block execution
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:

@@ -680,7 +1003,9 @@ class Cache:
            max_age = cache_control_args.get(
                "s-max-age", cache_control_args.get("s-maxage", float("inf"))
            )
            cached_result = await self.cache.async_get_cache(cache_key)
            cached_result = await self.cache.async_get_cache(
                cache_key, *args, **kwargs
            )
            return self._get_cache_logic(
                cached_result=cached_result, max_age=max_age
            )

@@ -73,10 +73,14 @@ litellm_settings:
  max_budget: 1.5000
  models: ["azure-gpt-3.5"]
  duration: None
  cache: True # set cache responses to True
  cache_params:
    type: "redis-semantic"
    similarity_threshold: 0.8
    redis_semantic_cache_embedding_model: azure-embedding-model
  upperbound_key_generate_params:
    max_budget: 100
    duration: "30d"
  # cache: True
  duration: "30d"
  # setting callback class
  # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]

@@ -1168,7 +1168,7 @@ class ProxyConfig:

                    verbose_proxy_logger.debug(f"passed cache type={cache_type}")

                    if cache_type == "redis":
                    if cache_type == "redis" or cache_type == "redis-semantic":
                        cache_host = litellm.get_secret("REDIS_HOST", None)
                        cache_port = litellm.get_secret("REDIS_PORT", None)
                        cache_password = litellm.get_secret("REDIS_PASSWORD", None)

@@ -1195,6 +1195,9 @@ class ProxyConfig:
                        f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
                    )
                    print()  # noqa
                    if cache_type == "redis-semantic":
                        # by default this should always be async
                        cache_params.update({"redis_semantic_cache_use_async": True})

                    # users can pass os.environ/ variables on the proxy - we should read them from the env
                    for key, value in cache_params.items():
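
The loop at the end of this hunk is truncated in the diff; it is where `os.environ/`-style values in `cache_params` get resolved before the cache is constructed. A hypothetical, self-contained sketch of that kind of resolution (the variable name and the use of plain `os.getenv` are illustrative only; the proxy itself reads secrets via `litellm.get_secret`, as shown above for the Redis credentials):

```python
import os

# hypothetical illustration - not the actual loop body from the commit
os.environ["EMBEDDING_MODEL"] = "azure-embedding-model"
cache_params = {
    "type": "redis-semantic",
    "similarity_threshold": 0.8,
    "redis_semantic_cache_embedding_model": "os.environ/EMBEDDING_MODEL",
}
for key, value in cache_params.items():
    if isinstance(value, str) and value.startswith("os.environ/"):
        # resolve the env-var reference to its value
        cache_params[key] = os.getenv(value.split("os.environ/", 1)[1])

print(cache_params)  # redis_semantic_cache_embedding_model resolved to "azure-embedding-model"
```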

@@ -987,3 +987,102 @@ def test_cache_context_managers():


# test_cache_context_managers()


@pytest.mark.skip(reason="beta test - new redis semantic cache")
def test_redis_semantic_cache_completion():
    litellm.set_verbose = True
    import logging

    logging.basicConfig(level=logging.DEBUG)

    random_number = random.randint(
        1, 100000
    )  # add a random number to ensure it's always adding / reading from cache

    print("testing semantic caching")
    litellm.cache = Cache(
        type="redis-semantic",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
        similarity_threshold=0.8,
        redis_semantic_cache_embedding_model="text-embedding-ada-002",
    )
    response1 = completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "user",
                "content": f"write a one sentence poem about: {random_number}",
            }
        ],
        max_tokens=20,
    )
    print(f"response1: {response1}")

    random_number = random.randint(1, 100000)

    response2 = completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "user",
                "content": f"write a one sentence poem about: {random_number}",
            }
        ],
        max_tokens=20,
    )
    print(f"response2: {response2}")
    assert response1.id == response2.id


# test_redis_cache_completion()


@pytest.mark.skip(reason="beta test - new redis semantic cache")
@pytest.mark.asyncio
async def test_redis_semantic_cache_acompletion():
    litellm.set_verbose = True
    import logging

    logging.basicConfig(level=logging.DEBUG)

    random_number = random.randint(
        1, 100000
    )  # add a random number to ensure it's always adding / reading from cache

    print("testing semantic caching")
    litellm.cache = Cache(
        type="redis-semantic",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
        similarity_threshold=0.8,
        redis_semantic_cache_use_async=True,
    )
    response1 = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "user",
                "content": f"write a one sentence poem about: {random_number}",
            }
        ],
        max_tokens=5,
    )
    print(f"response1: {response1}")

    random_number = random.randint(1, 100000)
    response2 = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "user",
                "content": f"write a one sentence poem about: {random_number}",
            }
        ],
        max_tokens=5,
    )
    print(f"response2: {response2}")
    assert response1.id == response2.id

@@ -55,7 +55,7 @@ from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache
from .caching import S3Cache, RedisSemanticCache
from .exceptions import (
    AuthenticationError,
    BadRequestError,

@@ -2533,6 +2533,14 @@ def client(original_function):
                    ):
                        if len(cached_result) == 1 and cached_result[0] is None:
                            cached_result = None
                    elif isinstance(litellm.cache.cache, RedisSemanticCache):
                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                        kwargs[
                            "preset_cache_key"
                        ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
                        cached_result = await litellm.cache.async_get_cache(
                            *args, **kwargs
                        )
                    else:
                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                        kwargs[

@@ -9,6 +9,8 @@ uvicorn==0.22.0 # server dep
gunicorn==21.2.0 # server dep
boto3==1.28.58 # aws bedrock/sagemaker calls
redis==4.6.0 # caching
redisvl==0.0.7 # semantic caching
numpy==1.24.3 # semantic caching
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
google-generativeai==0.3.2 # for vertex ai calls