litellm-mirror/litellm/caching/disk_cache.py
Krish Dholakia 4f8a3fd4cf
redis otel tracing + async support for latency routing (#6452)
* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered to the specific model+provider key combination

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* refactor: pass parent_otel_span for redis caching calls in router

allows for more observability into what calls are causing latency issues

* test: update tests with new params

* refactor: ensure e2e otel tracing for router

* refactor(router.py): add more otel tracing across router

catch all latency issues for router requests

* fix: fix linting error

* fix(router.py): fix linting error

* fix: fix test

* test: fix tests

* fix(dual_cache.py): pass ttl to redis cache

* fix: fix param
2024-10-28 21:52:12 -07:00

91 lines
2.8 KiB
Python

import json
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import print_verbose
from .base_cache import BaseCache
# Span is resolved to the real OpenTelemetry type only for static type
# checkers; at runtime it falls back to Any, so opentelemetry is not imported
# when this module is loaded.
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span
    Span = _Span
else:
    Span = Any
class DiskCache(BaseCache):
    """Cache backend that stores entries on local disk via ``diskcache``.

    The ``async_*`` methods are thin wrappers that call the synchronous
    implementations directly; they exist so this backend exposes the same
    coroutine interface as the other cache backends.
    """

    def __init__(self, disk_cache_dir: Optional[str] = None):
        """Open the on-disk cache.

        Args:
            disk_cache_dir: Directory passed to ``diskcache.Cache``. When
                ``None``, ``".litellm_cache"`` is used.
        """
        # Imported here rather than at module level so diskcache is only
        # needed when this backend is actually instantiated.
        import diskcache as dc

        # If users don't provide one, use the default litellm cache.
        directory = ".litellm_cache" if disk_cache_dir is None else disk_cache_dir
        self.disk_cache = dc.Cache(directory)

    def set_cache(self, key, value, **kwargs):
        """Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Value to store.
            **kwargs: Only ``ttl`` is read; when present it is forwarded as
                the entry's ``expire`` argument. Other keywords are ignored.
        """
        print_verbose("DiskCache: set_cache")
        if "ttl" in kwargs:
            self.disk_cache.set(key, value, expire=kwargs["ttl"])
        else:
            self.disk_cache.set(key, value)

    async def async_set_cache(self, key, value, **kwargs):
        """Coroutine wrapper around :meth:`set_cache`."""
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
        """Store every ``(key, value)`` pair from ``cache_list``.

        Args:
            cache_list: Iterable of ``(key, value)`` pairs.
            ttl: Optional expiry applied to every entry; omitted from the
                underlying call when ``None``.
            **kwargs: Accepted and ignored, so callers can pass the same
                extra keyword arguments that the other methods tolerate.
        """
        for cache_key, cache_value in cache_list:
            if ttl is not None:
                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
            else:
                self.set_cache(key=cache_key, value=cache_value)

    def get_cache(self, key, **kwargs):
        """Return the value stored under ``key``, or ``None`` on a miss.

        The stored value is passed through ``json.loads``; if that fails for
        any reason the stored value is returned unchanged.
        """
        original_cached_response = self.disk_cache.get(key)
        # Compare against None instead of testing truthiness: stored falsy
        # values (0, "", False, empty containers) are real hits and must not
        # be reported as misses.
        if original_cached_response is None:
            return None
        try:
            return json.loads(original_cached_response)  # type: ignore
        except Exception:
            # Deliberate best-effort: anything that is not decodable JSON is
            # handed back exactly as it was stored.
            return original_cached_response

    def batch_get_cache(self, keys: list, **kwargs):
        """Return a list with the :meth:`get_cache` result for each key, in order."""
        return [self.get_cache(key=k, **kwargs) for k in keys]

    def increment_cache(self, key, value: int, **kwargs) -> int:
        """Add ``value`` to the number stored under ``key`` and store the sum.

        A missing (or falsy) stored value counts as ``0``. ``kwargs`` are
        forwarded to :meth:`set_cache`. This is a separate read followed by a
        write, not a single atomic operation.

        Returns:
            The new stored value.
        """
        # get the value
        init_value = self.get_cache(key=key) or 0
        value = init_value + value  # type: ignore
        self.set_cache(key, value, **kwargs)
        return value

    async def async_get_cache(self, key, **kwargs):
        """Coroutine wrapper around :meth:`get_cache`."""
        return self.get_cache(key=key, **kwargs)

    async def async_batch_get_cache(self, keys: list, **kwargs):
        """Coroutine wrapper around :meth:`batch_get_cache`."""
        return self.batch_get_cache(keys=keys, **kwargs)

    async def async_increment(self, key, value: int, **kwargs) -> int:
        """Coroutine counterpart of :meth:`increment_cache`.

        Returns:
            The new stored value.
        """
        # get the value
        init_value = await self.async_get_cache(key=key) or 0
        value = init_value + value  # type: ignore
        await self.async_set_cache(key, value, **kwargs)
        return value

    def flush_cache(self):
        """Remove every entry from the cache."""
        self.disk_cache.clear()

    async def disconnect(self):
        """No-op; this backend does nothing on disconnect."""
        pass

    def delete_cache(self, key):
        """Remove the entry stored under ``key``; the popped value is discarded."""
        self.disk_cache.pop(key)