forked from phoenix/litellm-mirror
Merge pull request #3266 from antonioloison/litellm_add_disk_cache
[Feature] Add cache to disk
This commit is contained in:
commit
0c8f5e5649
6 changed files with 231 additions and 27 deletions
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Caching - In-Memory, Redis, s3, Redis Semantic Cache
|
||||
# Caching - In-Memory, Redis, s3, Redis Semantic Cache, Disk
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)
|
||||
|
||||
|
@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht
|
|||
|
||||
:::
|
||||
|
||||
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache
|
||||
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
|
||||
|
||||
|
||||
<Tabs>
|
||||
|
@ -159,7 +159,7 @@ litellm.cache = Cache()
|
|||
# Make completion calls
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}]
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
caching=True
|
||||
)
|
||||
response2 = completion(
|
||||
|
@ -174,6 +174,43 @@ response2 = completion(
|
|||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="disk" label="disk cache">
|
||||
|
||||
### Quick Start
|
||||
|
||||
Install diskcache:
|
||||
|
||||
```shell
|
||||
pip install diskcache
|
||||
```
|
||||
|
||||
Then you can use the disk cache as follows.
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm import completion
|
||||
from litellm.caching import Cache
|
||||
litellm.cache = Cache(type="disk")
|
||||
|
||||
# Make completion calls
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
caching=True
|
||||
)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
caching=True
|
||||
)
|
||||
|
||||
# response1 == response2, response 1 is cached
|
||||
|
||||
```
|
||||
|
||||
If you run the code two times, response1 will use the cache from the first run that was stored in a cache file.
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
@ -191,13 +228,13 @@ Advanced Params
|
|||
|
||||
```python
|
||||
litellm.enable_cache(
|
||||
type: Optional[Literal["local", "redis"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding"],
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
@ -215,13 +252,13 @@ Update the Cache params
|
|||
|
||||
```python
|
||||
litellm.update_cache(
|
||||
type: Optional[Literal["local", "redis"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding"],
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
@ -276,22 +313,29 @@ cache.get_cache = get_cache
|
|||
```python
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[Literal["local", "redis", "s3"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
|
||||
supported_call_types: Optional[
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding"], # A list of litellm call types to cache for. Defaults to caching for all litellm call types.
|
||||
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
|
||||
ttl: Optional[float] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
|
||||
# redis cache params
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
||||
namespace: Optional[str] = None,
|
||||
default_in_redis_ttl: Optional[float] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
redis_semantic_cache_use_async=False,
|
||||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
redis_flush_size=None,
|
||||
|
||||
# s3 Bucket, boto3 configuration
|
||||
s3_bucket_name: Optional[str] = None,
|
||||
s3_region_name: Optional[str] = None,
|
||||
s3_api_version: Optional[str] = None,
|
||||
s3_path: Optional[str] = None, # if you wish to save to a spefic path
|
||||
s3_path: Optional[str] = None, # if you wish to save to a specific path
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify: Optional[Union[bool, str]] = None,
|
||||
s3_endpoint_url: Optional[str] = None,
|
||||
|
@ -299,7 +343,11 @@ def __init__(
|
|||
s3_aws_secret_access_key: Optional[str] = None,
|
||||
s3_aws_session_token: Optional[str] = None,
|
||||
s3_config: Optional[Any] = None,
|
||||
**kwargs,
|
||||
|
||||
# disk cache params
|
||||
disk_cache_dir=None,
|
||||
|
||||
**kwargs
|
||||
):
|
||||
```
|
||||
|
|
@ -40,7 +40,7 @@ cache = Cache()
|
|||
|
||||
cache.add_cache(cache_key="test-key", result="1234")
|
||||
|
||||
cache.get_cache(cache_key="test-key)
|
||||
cache.get_cache(cache_key="test-key")
|
||||
```
|
||||
|
||||
## Caching with Streaming
|
||||
|
|
|
@ -189,7 +189,7 @@ const sidebars = {
|
|||
`observability/telemetry`,
|
||||
],
|
||||
},
|
||||
"caching/redis_cache",
|
||||
"caching/all_caches",
|
||||
{
|
||||
type: "category",
|
||||
label: "Tutorials",
|
||||
|
|
|
@ -1441,7 +1441,7 @@ class DualCache(BaseCache):
|
|||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
@ -1484,13 +1484,14 @@ class Cache:
|
|||
redis_semantic_cache_use_async=False,
|
||||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
redis_flush_size=None,
|
||||
disk_cache_dir=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
|
@ -1536,6 +1537,8 @@ class Cache:
|
|||
s3_path=s3_path,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == "disk":
|
||||
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
|
@ -1907,8 +1910,86 @@ class Cache:
|
|||
await self.cache.disconnect()
|
||||
|
||||
|
||||
class DiskCache(BaseCache):
|
||||
def __init__(self, disk_cache_dir: Optional[str] = None):
|
||||
import diskcache as dc
|
||||
|
||||
# if users don't provider one, use the default litellm cache
|
||||
if disk_cache_dir is None:
|
||||
self.disk_cache = dc.Cache(".litellm_cache")
|
||||
else:
|
||||
self.disk_cache = dc.Cache(disk_cache_dir)
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
print_verbose("DiskCache: set_cache")
|
||||
if "ttl" in kwargs:
|
||||
self.disk_cache.set(key, value, expire=kwargs["ttl"])
|
||||
else:
|
||||
self.disk_cache.set(key, value)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if ttl is not None:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
original_cached_response = self.disk_cache.get(key)
|
||||
if original_cached_response:
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
def flush_cache(self):
|
||||
self.disk_cache.clear()
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def delete_cache(self, key):
|
||||
self.disk_cache.pop(key)
|
||||
|
||||
|
||||
def enable_cache(
|
||||
type: Optional[Literal["local", "redis", "s3"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
@ -1937,7 +2018,7 @@ def enable_cache(
|
|||
Enable cache with the specified configuration.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis"]]): The type of cache to enable. Defaults to "local".
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
|
||||
host (Optional[str]): The host address of the cache server. Defaults to None.
|
||||
port (Optional[str]): The port number of the cache server. Defaults to None.
|
||||
password (Optional[str]): The password for the cache server. Defaults to None.
|
||||
|
@ -1973,7 +2054,7 @@ def enable_cache(
|
|||
|
||||
|
||||
def update_cache(
|
||||
type: Optional[Literal["local", "redis"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
@ -2002,7 +2083,7 @@ def update_cache(
|
|||
Update the cache for LiteLLM.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local".
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
|
||||
host (Optional[str]): The host of the cache. Defaults to None.
|
||||
port (Optional[str]): The port of the cache. Defaults to None.
|
||||
password (Optional[str]): The password for the cache. Defaults to None.
|
||||
|
|
|
@ -599,7 +599,10 @@ def test_redis_cache_completion():
|
|||
)
|
||||
print("test2 for Redis Caching - non streaming")
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
caching=True,
|
||||
max_tokens=20,
|
||||
)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
|
||||
|
@ -653,7 +656,6 @@ def test_redis_cache_completion():
|
|||
assert response1.created == response2.created
|
||||
assert response1.choices[0].message.content == response2.choices[0].message.content
|
||||
|
||||
|
||||
# test_redis_cache_completion()
|
||||
|
||||
|
||||
|
@ -875,6 +877,80 @@ async def test_redis_cache_acompletion_stream_bedrock():
|
|||
print(e)
|
||||
raise e
|
||||
|
||||
def test_disk_cache_completion():
|
||||
litellm.set_verbose = False
|
||||
|
||||
random_number = random.randint(
|
||||
1, 100000
|
||||
) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [
|
||||
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
|
||||
]
|
||||
litellm.cache = Cache(
|
||||
type="disk",
|
||||
)
|
||||
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
caching=True,
|
||||
max_tokens=20,
|
||||
mock_response="This number is so great!",
|
||||
)
|
||||
# response2 is mocked to a different response from response1,
|
||||
# but the completion from the cache should be used instead of the mock
|
||||
# response since the input is the same as response1
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
caching=True,
|
||||
max_tokens=20,
|
||||
mock_response="This number is awful!",
|
||||
)
|
||||
# Since the parameters are not the same as response1, response3 should actually
|
||||
# be the mock response
|
||||
response3 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
caching=True,
|
||||
temperature=0.5,
|
||||
mock_response="This number is awful!",
|
||||
)
|
||||
|
||||
print("\nresponse 1", response1)
|
||||
print("\nresponse 2", response2)
|
||||
print("\nresponse 3", response3)
|
||||
# print("\nresponse 4", response4)
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
||||
# 1 & 2 should be exactly the same
|
||||
# 1 & 3 should be different, since input params are diff
|
||||
if (
|
||||
response1["choices"][0]["message"]["content"]
|
||||
!= response2["choices"][0]["message"]["content"]
|
||||
): # 1 and 2 should be the same
|
||||
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
if (
|
||||
response1["choices"][0]["message"]["content"]
|
||||
== response3["choices"][0]["message"]["content"]
|
||||
):
|
||||
# if input params like max_tokens, temperature are diff it should NOT be a cache hit
|
||||
print(f"response1: {response1}")
|
||||
print(f"response3: {response3}")
|
||||
pytest.fail(
|
||||
f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
|
||||
f" occurred:"
|
||||
)
|
||||
|
||||
assert response1.id == response2.id
|
||||
assert response1.created == response2.created
|
||||
assert response1.choices[0].message.content == response2.choices[0].message.content
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -65,7 +65,6 @@ extra_proxy = [
|
|||
"resend"
|
||||
]
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
litellm = 'litellm:run_server'
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue