forked from phoenix/litellm-mirror
Router & Caching fixes:
- Add optional TTL to Cache parameters
- Fix tpm and rpm caching in Router
parent ee2e186c62
commit 5e1e8820b4
2 changed files with 85 additions and 83 deletions
@@ -11,6 +11,7 @@ import litellm
 import time
+import json
 
 
 def get_prompt(*args, **kwargs):
     # make this safe checks, it should not throw any exceptions
     if len(args) > 1:
@@ -23,81 +24,98 @@ def get_prompt(*args, **kwargs):
         return prompt
     return None
 
-class RedisCache():
+
+class BaseCache:
+    def set_cache(self, key, value, **kwargs):
+        raise NotImplementedError
+
+    def get_cache(self, key, **kwargs):
+        raise NotImplementedError
+
+
+class RedisCache(BaseCache):
     def __init__(self, host, port, password):
         import redis
         # if users don't provider one, use the default litellm cache
         self.redis_client = redis.Redis(host=host, port=port, password=password)
 
-    def set_cache(self, key, value):
+    def set_cache(self, key, value, **kwargs):
+        ttl = kwargs.get("ttl", None)
         try:
-            self.redis_client.set(key, str(value))
+            self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             print("LiteLLM Caching: Got exception from REDIS: ", e)
 
-    def get_cache(self, key):
+    def get_cache(self, key, **kwargs):
         try:
             # TODO convert this to a ModelResponse object
             cached_response = self.redis_client.get(key)
-            if cached_response!=None:
+            if cached_response != None:
                 # cached_response is in `b{} convert it to ModelResponse
                 cached_response = cached_response.decode("utf-8")  # Convert bytes to string
                 cached_response = json.loads(cached_response)  # Convert string to dictionary
-                cached_response['cache'] = True # set cache-hit flag to True
+                cached_response['cache'] = True  # set cache-hit flag to True
                 return cached_response
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             print("LiteLLM Caching: Got exception from REDIS: ", e)
 
-class HostedCache():
-    def set_cache(self, key, value):
+
+class HostedCache(BaseCache):
+    def set_cache(self, key, value, **kwargs):
+        if "ttl" in kwargs:
+            print("LiteLLM Caching: TTL is not supported for hosted cache!")
         # make a post request to api.litellm.ai/set_cache
         import requests
         url = f"https://api.litellm.ai/set_cache?key={key}&value={str(value)}"
-        requests.request("POST", url) # post request to set this in the hosted litellm cache
+        requests.request("POST", url)  # post request to set this in the hosted litellm cache
 
-    def get_cache(self, key):
+    def get_cache(self, key, **kwargs):
         import requests
         url = f"https://api.litellm.ai/get_cache?key={key}"
         cached_response = requests.request("GET", url)
         cached_response = cached_response.text
-        if cached_response == "NONE": # api.litellm.ai returns "NONE" if it's not a cache hit
-            return None
-        if cached_response!=None:
+        if cached_response == "NONE":  # api.litellm.ai returns "NONE" if it's not a cache hit
+            return None
+        if cached_response != None:
             try:
                 cached_response = json.loads(cached_response)  # Convert string to dictionary
-                cached_response['cache'] = True # set cache-hit flag to True
+                cached_response['cache'] = True  # set cache-hit flag to True
                 return cached_response
             except:
                 return cached_response
 
-class InMemoryCache():
+
+class InMemoryCache(BaseCache):
     def __init__(self):
         # if users don't provider one, use the default litellm cache
         self.cache_dict = {}
+        self.ttl_dict = {}
 
-    def set_cache(self, key, value):
-        #print("in set cache for inmem")
+    def set_cache(self, key, value, **kwargs):
         self.cache_dict[key] = value
-        #print(self.cache_dict)
+        if "ttl" in kwargs:
+            self.ttl_dict[key] = time.time() + kwargs["ttl"]
 
-    def get_cache(self, key):
-        #print("in get cache for inmem")
+    def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
-            #print("got a cache hit")
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    return None
             return self.cache_dict[key]
-        #print("got a cache miss")
         return None
 
-class Cache():
+
+class Cache:
     def __init__(
-            self,
-            type = "local",
-            host = None,
-            port = None,
-            password = None
-        ):
+        self,
+        type="local",
+        host=None,
+        port=None,
+        password=None
+    ):
         """
         Initializes the cache based on the given type.
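
The TTL plumbing above takes two forms. RedisCache can delegate expiry to Redis itself: redis-py's set() accepts ex=<seconds>, and ex=None simply means "no expiry", so passing the ttl kwarg straight through is safe whether or not the caller supplied one. InMemoryCache has no server to do that, so it records an absolute deadline per key and checks it lazily on read. A self-contained sketch of that lazy-expiry scheme, re-typed from the hunk above for illustration (not imported from litellm):

import time

class InMemoryCache:
    def __init__(self):
        self.cache_dict = {}  # key -> cached value
        self.ttl_dict = {}    # key -> absolute expiry timestamp (seconds since epoch)

    def set_cache(self, key, value, **kwargs):
        self.cache_dict[key] = value
        if "ttl" in kwargs:
            # record a deadline now; expiry is enforced lazily in get_cache
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict and time.time() > self.ttl_dict[key]:
                self.cache_dict.pop(key, None)  # expired: evict and report a miss
                return None
            return self.cache_dict[key]
        return None

cache = InMemoryCache()
cache.set_cache("counter", 42, ttl=0.05)
assert cache.get_cache("counter") == 42
time.sleep(0.06)
assert cache.get_cache("counter") is None  # deadline passed, entry evicted on read

One consequence of the lazy check: entries that are never read again are never evicted, which is fine for short-lived per-minute counters but worth remembering in long-running processes.
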
@@ -151,9 +169,9 @@ class Cache():
     def generate_streaming_content(self, content):
         chunk_size = 5  # Adjust the chunk size as needed
         for i in range(0, len(content), chunk_size):
-            yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i+chunk_size]}}]}
+            yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]}
             time.sleep(0.02)
 
     def get_cache(self, *args, **kwargs):
         """
         Retrieves the cached result for the given arguments.
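
generate_streaming_content() exists so a cache hit can masquerade as a live stream: when the caller asked for stream=True (see the next hunk), the cached full response is sliced into OpenAI-style delta chunks and yielded one by one. A minimal standalone sketch of the same idea; the chunk shape mirrors the diff, and the sleep only simulates pacing:

import time

def generate_streaming_content(content, chunk_size=5):
    # slice the cached string into delta chunks shaped like a streaming response
    for i in range(0, len(content), chunk_size):
        yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]}
        time.sleep(0.02)

collected = ""
for chunk in generate_streaming_content("Hello from the cache!"):
    collected += chunk['choices'][0]['delta']['content']
assert collected == "Hello from the cache!"
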
@@ -166,16 +184,16 @@ class Cache():
             The cached result if it exists, otherwise None.
         """
         try: # never block execution
-            if "cache_key" in kwargs:
+            if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
-            else:
+            else:
                 cache_key = self.get_cache_key(*args, **kwargs)
             if cache_key is not None:
                 cached_result = self.cache.get_cache(cache_key)
                 if cached_result != None and 'stream' in kwargs and kwargs['stream'] == True:
                     # if streaming is true and we got a cache hit, return a generator
-                    #print("cache hit and stream=True")
-                    #print(cached_result)
+                    # print("cache hit and stream=True")
+                    # print(cached_result)
                     return self.generate_streaming_content(cached_result["choices"][0]['message']['content'])
             return cached_result
         except:
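
The "cache_key" override in this hunk is what makes the Router fix possible: a caller can read and write under a deterministic key of its own choosing instead of the hash get_cache_key() derives from the request arguments. A sketch of that calling pattern; the key format is invented for illustration (the actual Router keys live in the second changed file, which this page does not show), and add_cache is assumed to take the result as its first positional argument, as in the hunk below:

from litellm.caching import Cache

cache = Cache(type="local")
key = "gpt-3.5-turbo:tpm:14-05"                # hypothetical per-model, per-minute counter key
cache.add_cache(1200, cache_key=key, ttl=60)   # ttl reaches set_cache via **kwargs (next hunk)
tokens_this_minute = cache.get_cache(cache_key=key)
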
@@ -193,20 +211,14 @@ class Cache():
             None
         """
         try:
-            if "cache_key" in kwargs:
+            if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
-            else:
+            else:
                 cache_key = self.get_cache_key(*args, **kwargs)
-            # print("adding to cache", cache_key, result)
-            # print(cache_key)
             if cache_key is not None:
-                self.cache.set_cache(cache_key, result)
+                # print("adding to cache", cache_key, result)
+                self.cache.set_cache(cache_key, result, **kwargs)
         except:
             pass
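
The last change is the crux of the TTL feature: add_cache() now forwards its keyword arguments to the backend's set_cache(), so a ttl supplied by the caller actually reaches RedisCache or InMemoryCache instead of being silently dropped. A stripped-down sketch of the forwarding chain with stand-in classes (not the litellm ones):

class Backend:
    def set_cache(self, key, value, **kwargs):
        # a real backend reads kwargs.get("ttl"); extra keys like cache_key are ignored
        print(f"set {key!r} ttl={kwargs.get('ttl')}")

class Cache:
    def __init__(self):
        self.cache = Backend()

    def add_cache(self, result, *args, **kwargs):
        cache_key = kwargs.get("cache_key")
        # before this commit: self.cache.set_cache(cache_key, result)  <- ttl was lost here
        self.cache.set_cache(cache_key, result, **kwargs)

Cache().add_cache("hello", cache_key="k", ttl=60)  # prints: set 'k' ttl=60

Note that forwarding **kwargs wholesale also passes cache_key (and any request kwargs) down to set_cache; the backends tolerate this because their signatures were widened to **kwargs in the same commit.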