Merge pull request #3266 from antonioloison/litellm_add_disk_cache

[Feature] Add cache to disk
commit a3fb6e8c34
Ishaan Jaff, 2024-05-14 09:24:01 -07:00, committed by GitHub
6 changed files with 231 additions and 27 deletions

@@ -1441,7 +1441,7 @@ class DualCache(BaseCache):
 class Cache:
     def __init__(
         self,
-        type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
+        type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
         host: Optional[str] = None,
         port: Optional[str] = None,
         password: Optional[str] = None,
@@ -1484,13 +1484,14 @@ class Cache:
         redis_semantic_cache_use_async=False,
         redis_semantic_cache_embedding_model="text-embedding-ada-002",
         redis_flush_size=None,
+        disk_cache_dir=None,
         **kwargs,
     ):
         """
         Initializes the cache based on the given type.

         Args:
-            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
+            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
             host (str, optional): The host address for the Redis cache. Required if type is "redis".
             port (int, optional): The port number for the Redis cache. Required if type is "redis".
             password (str, optional): The password for the Redis cache. Required if type is "redis".
@@ -1536,6 +1537,8 @@ class Cache:
                 s3_path=s3_path,
                 **kwargs,
             )
+        elif type == "disk":
+            self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
         if "cache" not in litellm.input_callback:
             litellm.input_callback.append("cache")
         if "cache" not in litellm.success_callback:
@@ -1907,8 +1910,86 @@ class Cache:
         await self.cache.disconnect()


+class DiskCache(BaseCache):
+    def __init__(self, disk_cache_dir: Optional[str] = None):
+        import diskcache as dc
+
+        # if users don't provide one, use the default litellm cache directory
+        if disk_cache_dir is None:
+            self.disk_cache = dc.Cache(".litellm_cache")
+        else:
+            self.disk_cache = dc.Cache(disk_cache_dir)
+
+    def set_cache(self, key, value, **kwargs):
+        print_verbose("DiskCache: set_cache")
+        if "ttl" in kwargs:
+            self.disk_cache.set(key, value, expire=kwargs["ttl"])
+        else:
+            self.disk_cache.set(key, value)
+
+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        for cache_key, cache_value in cache_list:
+            if ttl is not None:
+                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
+            else:
+                self.set_cache(key=cache_key, value=cache_value)
+
+    def get_cache(self, key, **kwargs):
+        original_cached_response = self.disk_cache.get(key)
+        if original_cached_response:
+            try:
+                cached_response = json.loads(original_cached_response)
+            except:
+                cached_response = original_cached_response
+            return cached_response
+        return None
+
+    def batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
+    def increment_cache(self, key, value: int, **kwargs) -> int:
+        # get the value
+        init_value = self.get_cache(key=key) or 0
+        value = init_value + value
+        self.set_cache(key, value, **kwargs)
+        return value
+
+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
+    async def async_batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
+    async def async_increment(self, key, value: int, **kwargs) -> int:
+        # get the value
+        init_value = await self.async_get_cache(key=key) or 0
+        value = init_value + value
+        await self.async_set_cache(key, value, **kwargs)
+        return value
+
+    def flush_cache(self):
+        self.disk_cache.clear()
+
+    async def disconnect(self):
+        pass
+
+    def delete_cache(self, key):
+        self.disk_cache.pop(key)
+
+
 def enable_cache(
-    type: Optional[Literal["local", "redis", "s3"]] = "local",
+    type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
     host: Optional[str] = None,
     port: Optional[str] = None,
     password: Optional[str] = None,
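
As an illustration only (again, not part of the diff), the new class can also be exercised directly; this assumes diskcache is installed and that DiskCache ends up importable from litellm.caching, which the excerpt's hunk context suggests but does not show explicitly:

    from litellm.caching import DiskCache

    cache = DiskCache(disk_cache_dir="/tmp/litellm_disk_cache")  # example path
    cache.set_cache("greeting", "hello", ttl=60)  # entry expires after 60 seconds
    print(cache.get_cache("greeting"))            # -> "hello"
    cache.delete_cache("greeting")
    cache.flush_cache()                           # clears everything stored on disk
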
@@ -1937,7 +2018,7 @@ def enable_cache(
     Enable cache with the specified configuration.

     Args:
-        type (Optional[Literal["local", "redis"]]): The type of cache to enable. Defaults to "local".
+        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
         host (Optional[str]): The host address of the cache server. Defaults to None.
         port (Optional[str]): The port number of the cache server. Defaults to None.
         password (Optional[str]): The password for the cache server. Defaults to None.
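
A corresponding sketch for enable_cache (illustrative, not from the diff), assuming it wires a Cache instance into litellm the same way as setting litellm.cache directly; with no directory given, DiskCache falls back to the default ".litellm_cache" folder:

    import litellm

    litellm.enable_cache(type="disk")
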
@@ -1973,7 +2054,7 @@ def enable_cache(


 def update_cache(
-    type: Optional[Literal["local", "redis"]] = "local",
+    type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
     host: Optional[str] = None,
     port: Optional[str] = None,
     password: Optional[str] = None,
@@ -2002,7 +2083,7 @@ def update_cache(
     Update the cache for LiteLLM.

     Args:
-        type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local".
+        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
         host (Optional[str]): The host of the cache. Defaults to None.
         port (Optional[str]): The port of the cache. Defaults to None.
         password (Optional[str]): The password for the cache. Defaults to None.
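
And one more illustrative line for the analogous runtime path (not from the diff); this assumes update_cache is re-exported at the litellm package level like enable_cache, which the excerpt does not show:

    import litellm

    # Switch an already-configured cache over to the disk backend.
    litellm.update_cache(type="disk")
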