Router & Caching fixes:

- Add optional TTL to Cache parameters
- Fix tpm and rpm caching in Router
seva 2023-10-30 13:29:35 +01:00
parent ee2e186c62
commit 5e1e8820b4
2 changed files with 85 additions and 83 deletions
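As a quick illustration of the first bullet (optional TTL), the sketch below exercises the InMemoryCache backend from the diff that follows. It is a minimal, hedged example: the import path litellm.caching is assumed (the file name is not shown in this view), and only methods visible in the diff are used.

import time
from litellm.caching import InMemoryCache  # assumed module path

cache = InMemoryCache()
# set_cache now accepts an optional ttl (in seconds) via **kwargs
cache.set_cache("my-key", {"choices": []}, ttl=1)
print(cache.get_cache("my-key"))   # cache hit -> {'choices': []}
time.sleep(1.5)
# once the TTL has elapsed, get_cache evicts the entry and returns None
print(cache.get_cache("my-key"))   # None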


@@ -11,6 +11,7 @@ import litellm
 import time
 import json
 def get_prompt(*args, **kwargs):
     # make this safe checks, it should not throw any exceptions
     if len(args) > 1:
@@ -23,81 +24,98 @@ def get_prompt(*args, **kwargs):
         return prompt
     return None
-class RedisCache():
+class BaseCache:
+    def set_cache(self, key, value, **kwargs):
+        raise NotImplementedError
+    def get_cache(self, key, **kwargs):
+        raise NotImplementedError
+class RedisCache(BaseCache):
     def __init__(self, host, port, password):
         import redis
         # if users don't provider one, use the default litellm cache
         self.redis_client = redis.Redis(host=host, port=port, password=password)
-    def set_cache(self, key, value):
+    def set_cache(self, key, value, **kwargs):
+        ttl = kwargs.get("ttl", None)
         try:
-            self.redis_client.set(key, str(value))
+            self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             print("LiteLLM Caching: Got exception from REDIS: ", e)
-    def get_cache(self, key):
+    def get_cache(self, key, **kwargs):
         try:
             # TODO convert this to a ModelResponse object
             cached_response = self.redis_client.get(key)
-            if cached_response!=None:
+            if cached_response != None:
                 # cached_response is in `b{} convert it to ModelResponse
                 cached_response = cached_response.decode("utf-8") # Convert bytes to string
                 cached_response = json.loads(cached_response) # Convert string to dictionary
-                cached_response['cache'] = True # set cache-hit flag to True
+                cached_response['cache'] = True  # set cache-hit flag to True
                 return cached_response
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             print("LiteLLM Caching: Got exception from REDIS: ", e)
-class HostedCache():
-    def set_cache(self, key, value):
+class HostedCache(BaseCache):
+    def set_cache(self, key, value, **kwargs):
+        if "ttl" in kwargs:
+            print("LiteLLM Caching: TTL is not supported for hosted cache!")
         # make a post request to api.litellm.ai/set_cache
         import requests
         url = f"https://api.litellm.ai/set_cache?key={key}&value={str(value)}"
-        requests.request("POST", url) # post request to set this in the hosted litellm cache
+        requests.request("POST", url)  # post request to set this in the hosted litellm cache
-    def get_cache(self, key):
+    def get_cache(self, key, **kwargs):
         import requests
         url = f"https://api.litellm.ai/get_cache?key={key}"
         cached_response = requests.request("GET", url)
         cached_response = cached_response.text
-        if cached_response == "NONE": # api.litellm.ai returns "NONE" if it's not a cache hit
-            return None
-        if cached_response!=None:
+        if cached_response == "NONE":  # api.litellm.ai returns "NONE" if it's not a cache hit
+            return None
+        if cached_response != None:
             try:
                 cached_response = json.loads(cached_response) # Convert string to dictionary
-                cached_response['cache'] = True # set cache-hit flag to True
+                cached_response['cache'] = True  # set cache-hit flag to True
                 return cached_response
             except:
                 return cached_response
-class InMemoryCache():
+class InMemoryCache(BaseCache):
     def __init__(self):
         # if users don't provider one, use the default litellm cache
         self.cache_dict = {}
+        self.ttl_dict = {}
-    def set_cache(self, key, value):
-        #print("in set cache for inmem")
+    def set_cache(self, key, value, **kwargs):
         self.cache_dict[key] = value
-        #print(self.cache_dict)
+        if "ttl" in kwargs:
+            self.ttl_dict[key] = time.time() + kwargs["ttl"]
-    def get_cache(self, key):
-        #print("in get cache for inmem")
+    def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
-            #print("got a cache hit")
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    return None
             return self.cache_dict[key]
-        #print("got a cache miss")
         return None
-class Cache():
+class Cache:
     def __init__(
-        self,
-        type = "local",
-        host = None,
-        port = None,
-        password = None
-    ):
+        self,
+        type="local",
+        host=None,
+        port=None,
+        password=None
+    ):
         """
         Initializes the cache based on the given type.
@@ -151,9 +169,9 @@ class Cache():
     def generate_streaming_content(self, content):
         chunk_size = 5 # Adjust the chunk size as needed
         for i in range(0, len(content), chunk_size):
-            yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i+chunk_size]}}]}
+            yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]}
             time.sleep(0.02)
     def get_cache(self, *args, **kwargs):
         """
         Retrieves the cached result for the given arguments.
@@ -166,16 +184,16 @@ class Cache():
             The cached result if it exists, otherwise None.
         """
         try: # never block execution
-            if "cache_key" in kwargs:
+            if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
-            else:
+            else:
                 cache_key = self.get_cache_key(*args, **kwargs)
             if cache_key is not None:
                 cached_result = self.cache.get_cache(cache_key)
                 if cached_result != None and 'stream' in kwargs and kwargs['stream'] == True:
                     # if streaming is true and we got a cache hit, return a generator
-                    #print("cache hit and stream=True")
-                    #print(cached_result)
+                    # print("cache hit and stream=True")
+                    # print(cached_result)
                     return self.generate_streaming_content(cached_result["choices"][0]['message']['content'])
                 return cached_result
         except:
@@ -193,20 +211,14 @@ class Cache():
             None
         """
         try:
-            if "cache_key" in kwargs:
+            if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
-            else:
+            else:
                 cache_key = self.get_cache_key(*args, **kwargs)
-            # print("adding to cache", cache_key, result)
-            # print(cache_key)
             if cache_key is not None:
-                # print("adding to cache", cache_key, result)
-                self.cache.set_cache(cache_key, result)
+                self.cache.set_cache(cache_key, result, **kwargs)
         except:
             pass
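For the top-level Cache wrapper, the last hunk above forwards **kwargs into self.cache.set_cache(...), so a ttl keyword passed when a result is added should reach whichever backend is configured. A hedged sketch, assuming the surrounding method is Cache.add_cache(result, *args, **kwargs) and that type="local" selects InMemoryCache (neither detail is fully visible in this diff):

from litellm.caching import Cache  # assumed module path

cache = Cache(type="local")
result = {"choices": [{"message": {"role": "assistant", "content": "hello"}}]}
# cache_key is honored explicitly; ttl rides along in **kwargs down to set_cache
cache.add_cache(result, cache_key="example-key", ttl=60)
print(cache.get_cache(cache_key="example-key"))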