Merge pull request #3266 from antonioloison/litellm_add_disk_cache

[Feature] Add cache to disk
This commit is contained in:
Ishaan Jaff 2024-05-14 09:24:01 -07:00 committed by GitHub
commit 0c8f5e5649
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 231 additions and 27 deletions

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Caching - In-Memory, Redis, s3, Redis Semantic Cache
# Caching - In-Memory, Redis, s3, Redis Semantic Cache, Disk
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)
@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht
:::
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
<Tabs>
@ -159,7 +159,7 @@ litellm.cache = Cache()
# Make completion calls
response1 = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Tell me a joke."}]
messages=[{"role": "user", "content": "Tell me a joke."}],
caching=True
)
response2 = completion(
@ -174,6 +174,43 @@ response2 = completion(
</TabItem>
<TabItem value="disk" label="disk cache">
### Quick Start
Install diskcache:
```shell
pip install diskcache
```
Then you can use the disk cache as follows.
```python
import litellm
from litellm import completion
from litellm.caching import Cache
litellm.cache = Cache(type="disk")
# Make completion calls
response1 = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Tell me a joke."}],
caching=True
)
response2 = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Tell me a joke."}],
caching=True
)
# response1 == response2, response 1 is cached
```
If you run the code two times, response1 will use the cache from the first run that was stored in a cache file.
</TabItem>
</Tabs>
@ -191,13 +228,13 @@ Advanced Params
```python
litellm.enable_cache(
type: Optional[Literal["local", "redis"]] = "local",
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
**kwargs,
)
```
@ -215,13 +252,13 @@ Update the Cache params
```python
litellm.update_cache(
type: Optional[Literal["local", "redis"]] = "local",
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
**kwargs,
)
```
@ -276,22 +313,29 @@ cache.get_cache = get_cache
```python
def __init__(
self,
type: Optional[Literal["local", "redis", "s3"]] = "local",
type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"], # A list of litellm call types to cache for. Defaults to caching for all litellm call types.
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
ttl: Optional[float] = None,
default_in_memory_ttl: Optional[float] = None,
# redis cache params
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
namespace: Optional[str] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None,
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None,
s3_path: Optional[str] = None, # if you wish to save to a spefic path
s3_path: Optional[str] = None, # if you wish to save to a specific path
s3_use_ssl: Optional[bool] = True,
s3_verify: Optional[Union[bool, str]] = None,
s3_endpoint_url: Optional[str] = None,
@ -299,7 +343,11 @@ def __init__(
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None,
**kwargs,
# disk cache params
disk_cache_dir=None,
**kwargs
):
```

View file

@ -40,7 +40,7 @@ cache = Cache()
cache.add_cache(cache_key="test-key", result="1234")
cache.get_cache(cache_key="test-key)
cache.get_cache(cache_key="test-key")
```
## Caching with Streaming

View file

@ -189,7 +189,7 @@ const sidebars = {
`observability/telemetry`,
],
},
"caching/redis_cache",
"caching/all_caches",
{
type: "category",
label: "Tutorials",

View file

@ -1441,7 +1441,7 @@ class DualCache(BaseCache):
class Cache:
def __init__(
self,
type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
@ -1484,13 +1484,14 @@ class Cache:
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
disk_cache_dir=None,
**kwargs,
):
"""
Initializes the cache based on the given type.
Args:
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
@ -1536,6 +1537,8 @@ class Cache:
s3_path=s3_path,
**kwargs,
)
elif type == "disk":
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
@ -1907,8 +1910,86 @@ class Cache:
await self.cache.disconnect()
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
import diskcache as dc
# if users don't provider one, use the default litellm cache
if disk_cache_dir is None:
self.disk_cache = dc.Cache(".litellm_cache")
else:
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
print_verbose("DiskCache: set_cache")
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
self.disk_cache.set(key, value)
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None):
for cache_key, cache_value in cache_list:
if ttl is not None:
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
else:
self.set_cache(key=cache_key, value=cache_value)
def get_cache(self, key, **kwargs):
original_cached_response = self.disk_cache.get(key)
if original_cached_response:
try:
cached_response = json.loads(original_cached_response)
except:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.disk_cache.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.disk_cache.pop(key)
def enable_cache(
type: Optional[Literal["local", "redis", "s3"]] = "local",
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
@ -1937,7 +2018,7 @@ def enable_cache(
Enable cache with the specified configuration.
Args:
type (Optional[Literal["local", "redis"]]): The type of cache to enable. Defaults to "local".
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
host (Optional[str]): The host address of the cache server. Defaults to None.
port (Optional[str]): The port number of the cache server. Defaults to None.
password (Optional[str]): The password for the cache server. Defaults to None.
@ -1973,7 +2054,7 @@ def enable_cache(
def update_cache(
type: Optional[Literal["local", "redis"]] = "local",
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
@ -2002,7 +2083,7 @@ def update_cache(
Update the cache for LiteLLM.
Args:
type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local".
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
host (Optional[str]): The host of the cache. Defaults to None.
port (Optional[str]): The port of the cache. Defaults to None.
password (Optional[str]): The password for the cache. Defaults to None.

View file

@ -599,7 +599,10 @@ def test_redis_cache_completion():
)
print("test2 for Redis Caching - non streaming")
response1 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
model="gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
)
response2 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
@ -653,7 +656,6 @@ def test_redis_cache_completion():
assert response1.created == response2.created
assert response1.choices[0].message.content == response2.choices[0].message.content
# test_redis_cache_completion()
@ -875,6 +877,80 @@ async def test_redis_cache_acompletion_stream_bedrock():
print(e)
raise e
def test_disk_cache_completion():
litellm.set_verbose = False
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="disk",
)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is so great!",
)
# response2 is mocked to a different response from response1,
# but the completion from the cache should be used instead of the mock
# response since the input is the same as response1
response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is awful!",
)
# Since the parameters are not the same as response1, response3 should actually
# be the mock response
response3 = completion(
model="gpt-3.5-turbo",
messages=messages,
caching=True,
temperature=0.5,
mock_response="This number is awful!",
)
print("\nresponse 1", response1)
print("\nresponse 2", response2)
print("\nresponse 3", response3)
# print("\nresponse 4", response4)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# 1 & 2 should be exactly the same
# 1 & 3 should be different, since input params are diff
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
): # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
if (
response1["choices"][0]["message"]["content"]
== response3["choices"][0]["message"]["content"]
):
# if input params like max_tokens, temperature are diff it should NOT be a cache hit
print(f"response1: {response1}")
print(f"response3: {response3}")
pytest.fail(
f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
f" occurred:"
)
assert response1.id == response2.id
assert response1.created == response2.created
assert response1.choices[0].message.content == response2.choices[0].message.content
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio

View file

@ -65,7 +65,6 @@ extra_proxy = [
"resend"
]
[tool.poetry.scripts]
litellm = 'litellm:run_server'