import json
from typing import TYPE_CHECKING, Any, Optional, Union

from .base_cache import BaseCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


class DiskCache(BaseCache):
    """Cache backend that stores entries on local disk through ``diskcache.Cache``."""

    def __init__(
        self,
        disk_cache_dir: Optional[str] = None,
        disk_cache_size_limit: Optional[float] = None,
        disk_cache_cull_limit: Optional[int] = None,
    ):
        """Open (or create) the on-disk cache.

        Args:
            disk_cache_dir: Directory holding the cache files. When ``None``,
                ``".litellm_cache"`` is used.
            disk_cache_size_limit: Passed to ``diskcache.Cache`` as
                ``size_limit`` when not ``None``.
            disk_cache_cull_limit: Passed to ``diskcache.Cache`` as
                ``cull_limit`` when not ``None``.
        """
        # Imported here rather than at module level, so the dependency is only
        # needed once this backend is actually instantiated.
        import diskcache as dc

        disk_cache_params: dict = {
            # If the user doesn't provide a directory, use the default litellm cache.
            "directory": ".litellm_cache" if disk_cache_dir is None else disk_cache_dir,
        }
        if disk_cache_size_limit is not None:
            disk_cache_params["size_limit"] = disk_cache_size_limit
        if disk_cache_cull_limit is not None:
            disk_cache_params["cull_limit"] = disk_cache_cull_limit

        self.disk_cache = dc.Cache(**disk_cache_params)

    def set_cache(self, key, value, **kwargs):
        """Store ``value`` under ``key``.

        A ``ttl`` keyword, if present, is forwarded to diskcache as ``expire``;
        any other keyword arguments are ignored.
        """
        if "ttl" in kwargs:
            self.disk_cache.set(key, value, expire=kwargs["ttl"])
        else:
            self.disk_cache.set(key, value)

    async def async_set_cache(self, key, value, **kwargs):
        """Async wrapper around :meth:`set_cache` (the write itself is synchronous)."""
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Store every ``(key, value)`` pair in ``cache_list``.

        A ``ttl`` keyword, if present, is applied to every entry; other keyword
        arguments are ignored.
        """
        # Forward only "ttl": it is the one keyword set_cache reads.
        ttl_kwargs = {"ttl": kwargs["ttl"]} if "ttl" in kwargs else {}
        for cache_key, cache_value in cache_list:
            self.set_cache(key=cache_key, value=cache_value, **ttl_kwargs)

    def get_cache(self, key, **kwargs):
        """Return the cached value for ``key``, or ``None`` if it is absent.

        A stored value that ``json.loads`` accepts (a str/bytes holding valid
        JSON) is returned decoded; any other stored value is returned as-is.
        """
        original_cached_response = self.disk_cache.get(key)
        # diskcache's get() returns None for a missing key. Compare against None
        # explicitly so falsy stored values (0, "", [], False) are returned
        # rather than being reported as misses.
        if original_cached_response is None:
            return None
        try:
            return json.loads(original_cached_response)  # type: ignore
        except (TypeError, ValueError):
            # TypeError: the stored value is not str/bytes (e.g. a dict).
            # ValueError (incl. JSONDecodeError): the string is not valid JSON.
            return original_cached_response

    def batch_get_cache(self, keys: list, **kwargs):
        """Return ``[get_cache(k) for k in keys]``, preserving the order of ``keys``."""
        return [self.get_cache(key=k, **kwargs) for k in keys]

    def increment_cache(self, key, value: int, **kwargs) -> int:
        """Add ``value`` to the number stored at ``key``, store it, and return it.

        A missing (or falsy) current value counts as 0. ``kwargs`` (e.g.
        ``ttl``) are forwarded to :meth:`set_cache`.

        NOTE: this is a separate read then write with no lock, so concurrent
        increments of the same key can lose updates.
        """
        init_value = self.get_cache(key=key) or 0
        value = init_value + value  # type: ignore
        self.set_cache(key, value, **kwargs)
        return value

    async def async_get_cache(self, key, **kwargs):
        """Async wrapper around :meth:`get_cache` (the read itself is synchronous)."""
        return self.get_cache(key=key, **kwargs)

    async def async_batch_get_cache(self, keys: list, **kwargs):
        """Async wrapper around :meth:`batch_get_cache`."""
        return self.batch_get_cache(keys, **kwargs)

    async def async_increment(self, key, value: int, **kwargs) -> int:
        """Async counterpart of :meth:`increment_cache`; same read-then-write caveat."""
        init_value = await self.async_get_cache(key=key) or 0
        value = init_value + value  # type: ignore
        await self.async_set_cache(key, value, **kwargs)
        return value

    def flush_cache(self):
        """Remove every entry from the disk cache."""
        self.disk_cache.clear()

    async def disconnect(self):
        """No-op: this backend holds no connection that needs explicit teardown here."""
        pass

    def delete_cache(self, key):
        """Remove ``key`` from the cache; a missing key is not an error."""
        # pop() (rather than del) returns its default instead of raising on a miss.
        self.disk_cache.pop(key)