litellm/litellm/caching/disk_cache.py
Krish Dholakia 9160d80fa5
LiteLLM Minor Fixes & Improvements (11/12/2024) (#6705)
* fix(caching): convert arg to equivalent kwargs in llm caching handler

prevent unexpected errors

* fix(caching_handler.py): don't pass args to caching

* fix(caching): remove all *args from caching.py

* fix(caching): consistent function signatures + abc method

* test(caching_unit_tests.py): add unit tests for llm caching

ensures coverage for common caching scenarios across different implementations

* refactor(litellm_logging.py): move to using cache key from hidden params instead of regenerating one

* fix(router.py): drop redis password requirement

* fix(proxy_server.py): fix faulty slack alerting check

* fix(langfuse.py): avoid copying functions/thread lock objects in metadata

fixes metadata copy error when parent otel span in metadata

* test: update test
2024-11-12 22:50:51 +05:30

90 lines
2.8 KiB
Python

import json
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import print_verbose
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
import diskcache as dc
# if users don't provider one, use the default litellm cache
if disk_cache_dir is None:
self.disk_cache = dc.Cache(".litellm_cache")
else:
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
self.disk_cache.set(key, value)
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, **kwargs):
for cache_key, cache_value in cache_list:
if "ttl" in kwargs:
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
else:
self.set_cache(key=cache_key, value=cache_value)
def get_cache(self, key, **kwargs):
original_cached_response = self.disk_cache.get(key)
if original_cached_response:
try:
cached_response = json.loads(original_cached_response) # type: ignore
except Exception:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value # type: ignore
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value # type: ignore
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.disk_cache.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.disk_cache.pop(key)