fix(router.py): fix caching for tracking cooldowns + usage

Krrish Dholakia 2023-11-23 11:13:24 -08:00
parent 94c1d71b2c
commit 61fc76a8c4
5 changed files with 148 additions and 75 deletions


@@ -23,6 +23,9 @@ def get_prompt(*args, **kwargs):
        return prompt
    return None

def print_verbose(print_statement):
    if litellm.set_verbose:
        print(print_statement)  # noqa

class BaseCache:
    def set_cache(self, key, value, **kwargs):
@@ -32,6 +35,34 @@ class BaseCache:
        raise NotImplementedError

class InMemoryCache(BaseCache):
    def __init__(self):
        # if users don't provide one, use the default litellm cache
        self.cache_dict = {}
        self.ttl_dict = {}

    def set_cache(self, key, value, **kwargs):
        self.cache_dict[key] = value
        if "ttl" in kwargs:
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict:
                if time.time() > self.ttl_dict[key]:
                    self.cache_dict.pop(key, None)
                    return None
            original_cached_response = self.cache_dict[key]
            try:
                cached_response = json.loads(original_cached_response)
            except:
                cached_response = original_cached_response
            if isinstance(cached_response, dict):
                cached_response['cache'] = True  # set cache-hit flag to True
            return cached_response
        return None
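
The isinstance() check above is the substantive change in this hunk: the router now caches plain values (usage counts, cooldown lists) alongside dict responses, and only a dict can carry the cache-hit flag. A minimal sketch of the resulting behavior, assuming the InMemoryCache defined above; the keys and values are illustrative, not taken from this commit:

cache = InMemoryCache()

# A JSON string round-trips to a dict and picks up the cache-hit flag.
cache.set_cache("response", '{"choices": []}')
assert cache.get_cache("response") == {"choices": [], "cache": True}

# A plain counter is returned unchanged; without the isinstance() guard,
# cached_response['cache'] = True would raise a TypeError here.
cache.set_cache("usage:gpt-4", 42, ttl=60)
assert cache.get_cache("usage:gpt-4") == 42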
class RedisCache(BaseCache):
    def __init__(self, host, port, password):
        import redis
@@ -65,7 +96,58 @@ class RedisCache(BaseCache):
            traceback.print_exc()
            logging.debug("LiteLLM Caching: get() - Got exception from REDIS: %s", e)
class DualCache(BaseCache):
    """
    This updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache + Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """
    def __init__(self, in_memory_cache: InMemoryCache = None, redis_cache: RedisCache = None) -> None:
        super().__init__()
        # If in_memory_cache is not provided, use the default InMemoryCache
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        # If redis_cache is not provided, Redis is skipped entirely
        self.redis_cache = redis_cache
    def set_cache(self, key, value, **kwargs):
        # Update both Redis and the in-memory cache
        try:
            print_verbose(f"set cache: key: {key}; value: {value}")
            if self.in_memory_cache is not None:
                self.in_memory_cache.set_cache(key, value, **kwargs)
            if self.redis_cache is not None:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)
    def get_cache(self, key, **kwargs):
        # Try to fetch from the in-memory cache first
        try:
            print_verbose(f"get cache: cache key: {key}")
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
                if in_memory_result is not None:
                    result = in_memory_result
            if result is None and self.redis_cache is not None:
                # If not found in the in-memory cache, try fetching from Redis
                redis_result = self.redis_cache.get_cache(key, **kwargs)
                if redis_result is not None:
                    # Update the in-memory cache with the value from Redis
                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)
                result = redis_result
            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception as e:
            traceback.print_exc()
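
DualCache is what the router-side fix builds on: cooldown and usage writes land in process memory immediately, with Redis as the shared layer across router instances. A hedged usage sketch; the key format and deployment name here are hypothetical, not taken from this commit:

dual_cache = DualCache()  # no redis_cache given -> in-memory only

# Mark a deployment as cooled down for 60 seconds. The write hits the
# in-memory cache at once (and Redis too, when one is configured).
dual_cache.set_cache("cooldown:azure/gpt-4", ["deployment-1"], ttl=60)

# Reads check the in-memory cache first, so the cooldown is visible
# without waiting on a Redis round-trip.
print(dual_cache.get_cache("cooldown:azure/gpt-4"))  # ['deployment-1']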
#### DEPRECATED ####
class HostedCache(BaseCache):
    def set_cache(self, key, value, **kwargs):
        if "ttl" in kwargs:
@@ -91,33 +173,7 @@ class HostedCache(BaseCache):
        return cached_response
class InMemoryCache(BaseCache):
    def __init__(self):
        # if users don't provide one, use the default litellm cache
        self.cache_dict = {}
        self.ttl_dict = {}

    def set_cache(self, key, value, **kwargs):
        self.cache_dict[key] = value
        if "ttl" in kwargs:
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict:
                if time.time() > self.ttl_dict[key]:
                    self.cache_dict.pop(key, None)
                    return None
            original_cached_response = self.cache_dict[key]
            try:
                cached_response = json.loads(original_cached_response)
            except:
                cached_response = original_cached_response
            cached_response['cache'] = True  # set cache-hit flag to True
            return cached_response
        return None
#### LiteLLM.Completion Cache ####
class Cache:
    def __init__(
        self,