forked from phoenix/litellm-mirror
fix(router.py): fix caching for tracking cooldowns + usage
This commit is contained in:
parent
94c1d71b2c
commit
61fc76a8c4
5 changed files with 148 additions and 75 deletions
|
@ -23,6 +23,9 @@ def get_prompt(*args, **kwargs):
|
|||
return prompt
|
||||
return None
|
||||
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
|
||||
class BaseCache:
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
|
@ -32,6 +35,34 @@ class BaseCache:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
def __init__(self):
|
||||
# if users don't provider one, use the default litellm cache
|
||||
self.cache_dict = {}
|
||||
self.ttl_dict = {}
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
self.cache_dict[key] = value
|
||||
if "ttl" in kwargs:
|
||||
self.ttl_dict[key] = time.time() + kwargs["ttl"]
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
if key in self.cache_dict:
|
||||
if key in self.ttl_dict:
|
||||
if time.time() > self.ttl_dict[key]:
|
||||
self.cache_dict.pop(key, None)
|
||||
return None
|
||||
original_cached_response = self.cache_dict[key]
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except:
|
||||
cached_response = original_cached_response
|
||||
if isinstance(cached_response, dict):
|
||||
cached_response['cache'] = True # set cache-hit flag to True
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
|
||||
class RedisCache(BaseCache):
|
||||
def __init__(self, host, port, password):
|
||||
import redis
|
||||
|
@ -65,7 +96,58 @@ class RedisCache(BaseCache):
|
|||
traceback.print_exc()
|
||||
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
|
||||
|
||||
class DualCache(BaseCache):
|
||||
"""
|
||||
This updates both Redis and an in-memory cache simultaneously.
|
||||
When data is updated or inserted, it is written to both the in-memory cache + Redis.
|
||||
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
|
||||
"""
|
||||
def __init__(self, in_memory_cache: InMemoryCache =None, redis_cache: RedisCache =None) -> None:
|
||||
super().__init__()
|
||||
# If in_memory_cache is not provided, use the default InMemoryCache
|
||||
self.in_memory_cache = in_memory_cache or InMemoryCache()
|
||||
# If redis_cache is not provided, use the default RedisCache
|
||||
self.redis_cache = redis_cache
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
# Update both Redis and in-memory cache
|
||||
try:
|
||||
print_verbose(f"set cache: key: {key}; value: {value}")
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None:
|
||||
self.redis_cache.set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(e)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
print_verbose(f"get cache: cache key: {key}")
|
||||
result = None
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
|
||||
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if self.redis_cache is not None:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = self.redis_cache.get_cache(key, **kwargs)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
|
||||
|
||||
result = redis_result
|
||||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
#### DEPRECATED ####
|
||||
class HostedCache(BaseCache):
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
if "ttl" in kwargs:
|
||||
|
@ -91,33 +173,7 @@ class HostedCache(BaseCache):
|
|||
return cached_response
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
def __init__(self):
|
||||
# if users don't provider one, use the default litellm cache
|
||||
self.cache_dict = {}
|
||||
self.ttl_dict = {}
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
self.cache_dict[key] = value
|
||||
if "ttl" in kwargs:
|
||||
self.ttl_dict[key] = time.time() + kwargs["ttl"]
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
if key in self.cache_dict:
|
||||
if key in self.ttl_dict:
|
||||
if time.time() > self.ttl_dict[key]:
|
||||
self.cache_dict.pop(key, None)
|
||||
return None
|
||||
original_cached_response = self.cache_dict[key]
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except:
|
||||
cached_response = original_cached_response
|
||||
cached_response['cache'] = True # set cache-hit flag to True
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
|
||||
#### LiteLLM.Completion Cache ####
|
||||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue