forked from phoenix/litellm-mirror
v0 of caching
This commit is contained in:
parent
8a76c80039
commit
8b43917792
4 changed files with 137 additions and 49 deletions
|
@ -1,5 +1,6 @@
|
||||||
import threading
|
import threading
|
||||||
from typing import Callable, List, Optional, Dict
|
from typing import Callable, List, Optional, Dict
|
||||||
|
from litellm.caching import Cache
|
||||||
|
|
||||||
input_callback: List[str] = []
|
input_callback: List[str] = []
|
||||||
success_callback: List[str] = []
|
success_callback: List[str] = []
|
||||||
|
@ -30,6 +31,7 @@ baseten_key: Optional[str] = None
|
||||||
use_client = False
|
use_client = False
|
||||||
logging = True
|
logging = True
|
||||||
caching = False
|
caching = False
|
||||||
|
cache: Optional[Cache] = None # set to litellm.caching Cache() object
|
||||||
caching_with_models = False # if you want the caching key to be model + prompt
|
caching_with_models = False # if you want the caching key to be model + prompt
|
||||||
model_alias_map: Dict[str, str] = {}
|
model_alias_map: Dict[str, str] = {}
|
||||||
model_cost = {
|
model_cost = {
|
||||||
|
|
81
litellm/caching.py
Normal file
81
litellm/caching.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import redis
|
||||||
|
import litellm, openai
|
||||||
|
|
||||||
|
def get_prompt(*args, **kwargs):
|
||||||
|
# make this safe checks, it should not throw any exceptions
|
||||||
|
if len(args) > 1:
|
||||||
|
messages = args[1]
|
||||||
|
prompt = " ".join(message["content"] for message in messages)
|
||||||
|
return prompt
|
||||||
|
if "messages" in kwargs:
|
||||||
|
messages = kwargs["messages"]
|
||||||
|
prompt = " ".join(message["content"] for message in messages)
|
||||||
|
return prompt
|
||||||
|
return None
|
||||||
|
|
||||||
|
class RedisCache():
|
||||||
|
import redis
|
||||||
|
def __init__(self, host, port, password):
|
||||||
|
# if users don't provider one, use the default litellm cache
|
||||||
|
self.redis_client = redis.Redis(host=host, port=port, password=password)
|
||||||
|
|
||||||
|
def set_cache(self, key, value):
|
||||||
|
self.redis_client.set(key, str(value))
|
||||||
|
|
||||||
|
def get_cache(self, key):
|
||||||
|
# TODO convert this to a ModelResponse object
|
||||||
|
return self.redis_client.get(key)
|
||||||
|
|
||||||
|
class InMemoryCache():
|
||||||
|
def __init__(self):
|
||||||
|
# if users don't provider one, use the default litellm cache
|
||||||
|
self.cache_dict = {}
|
||||||
|
|
||||||
|
def set_cache(self, key, value):
|
||||||
|
self.cache_dict[key] = value
|
||||||
|
|
||||||
|
def get_cache(self, key):
|
||||||
|
if key in self.cache_dict:
|
||||||
|
return self.cache_dict[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
class Cache():
|
||||||
|
def __init__(self, type="local", host="", port="", password=""):
|
||||||
|
if type == "redis":
|
||||||
|
self.cache = RedisCache(type, host, port, password)
|
||||||
|
if type == "local":
|
||||||
|
self.cache = InMemoryCache()
|
||||||
|
|
||||||
|
def check_cache(self, *args, **kwargs):
|
||||||
|
try: # never block execution
|
||||||
|
prompt = get_prompt(*args, **kwargs)
|
||||||
|
if prompt != None: # check if messages / prompt exists
|
||||||
|
if "model" in kwargs: # default to caching with `model + prompt` as key
|
||||||
|
cache_key = prompt + kwargs["model"]
|
||||||
|
return self.cache.get_cache(cache_key)
|
||||||
|
else:
|
||||||
|
return self.cache.get_cache(prompt)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add_cache(self, result, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
prompt = get_prompt(*args, **kwargs)
|
||||||
|
if "model" in kwargs: # default to caching with `model + prompt` as key
|
||||||
|
cache_key = prompt + kwargs["model"]
|
||||||
|
self.cache.set_cache(cache_key, result)
|
||||||
|
else:
|
||||||
|
self.cache.set_cache(prompt, result)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ sys.path.insert(
|
||||||
import pytest
|
import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion
|
from litellm import embedding, completion
|
||||||
|
from litellm.caching import Cache
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
||||||
|
|
||||||
|
@ -78,3 +79,50 @@ def test_gpt_cache():
|
||||||
|
|
||||||
|
|
||||||
# test_gpt_cache()
|
# test_gpt_cache()
|
||||||
|
|
||||||
|
|
||||||
|
####### Updated Caching as of Aug 28, 2023 ###################
|
||||||
|
messages = [{"role": "user", "content": "who is ishaan 5222"}]
|
||||||
|
def test_caching():
|
||||||
|
try:
|
||||||
|
litellm.cache = Cache()
|
||||||
|
response1 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
|
response2 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
|
print(f"response1: {response1}")
|
||||||
|
print(f"response2: {response2}")
|
||||||
|
litellm.cache = None # disable cache
|
||||||
|
if response2 != response1:
|
||||||
|
print(f"response1: {response1}")
|
||||||
|
print(f"response2: {response2}")
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"error occurred: {traceback.format_exc()}")
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
# test_caching()
|
||||||
|
|
||||||
|
|
||||||
|
def test_caching_with_models():
|
||||||
|
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
|
||||||
|
litellm.cache = Cache()
|
||||||
|
print("test2 for caching")
|
||||||
|
response1 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
|
response2 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
|
response3 = completion(model="command-nightly", messages=messages)
|
||||||
|
print(f"response1: {response1}")
|
||||||
|
print(f"response2: {response2}")
|
||||||
|
print(f"response3: {response3}")
|
||||||
|
litellm.cache = None
|
||||||
|
if response3 == response2:
|
||||||
|
# if models are different, it should not return cached response
|
||||||
|
print(f"response2: {response2}")
|
||||||
|
print(f"response3: {response3}")
|
||||||
|
pytest.fail(f"Error occurred:")
|
||||||
|
if response1 != response2:
|
||||||
|
print(f"response1: {response1}")
|
||||||
|
print(f"response2: {response2}")
|
||||||
|
pytest.fail(f"Error occurred:")
|
||||||
|
|
||||||
|
# test_caching_with_models()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -393,50 +393,6 @@ def client(original_function):
|
||||||
# [Non-Blocking Error]
|
# [Non-Blocking Error]
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_prompt(*args, **kwargs):
|
|
||||||
# make this safe checks, it should not throw any exceptions
|
|
||||||
if len(args) > 1:
|
|
||||||
messages = args[1]
|
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
|
||||||
return prompt
|
|
||||||
if "messages" in kwargs:
|
|
||||||
messages = kwargs["messages"]
|
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
|
||||||
return prompt
|
|
||||||
return None
|
|
||||||
|
|
||||||
def check_cache(*args, **kwargs):
|
|
||||||
try: # never block execution
|
|
||||||
prompt = get_prompt(*args, **kwargs)
|
|
||||||
if prompt != None: # check if messages / prompt exists
|
|
||||||
if litellm.caching_with_models:
|
|
||||||
# if caching with model names is enabled, key is prompt + model name
|
|
||||||
if "model" in kwargs:
|
|
||||||
cache_key = prompt + kwargs["model"]
|
|
||||||
if cache_key in local_cache:
|
|
||||||
return local_cache[cache_key]
|
|
||||||
else: # caching only with prompts
|
|
||||||
if prompt in local_cache:
|
|
||||||
result = local_cache[prompt]
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
return None # default to return None
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add_cache(result, *args, **kwargs):
|
|
||||||
try: # never block execution
|
|
||||||
prompt = get_prompt(*args, **kwargs)
|
|
||||||
if litellm.caching_with_models: # caching with model + prompt
|
|
||||||
if "model" in kwargs:
|
|
||||||
cache_key = prompt + kwargs["model"]
|
|
||||||
local_cache[cache_key] = result
|
|
||||||
else: # caching based only on prompts
|
|
||||||
local_cache[prompt] = result
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
start_time = None
|
start_time = None
|
||||||
result = None
|
result = None
|
||||||
|
@ -446,19 +402,20 @@ def client(original_function):
|
||||||
kwargs["litellm_call_id"] = litellm_call_id
|
kwargs["litellm_call_id"] = litellm_call_id
|
||||||
start_time = datetime.datetime.now()
|
start_time = datetime.datetime.now()
|
||||||
# [OPTIONAL] CHECK CACHE
|
# [OPTIONAL] CHECK CACHE
|
||||||
if (litellm.caching or litellm.caching_with_models) and (
|
if (litellm.caching or litellm.caching_with_models or litellm.cache != None) and (
|
||||||
cached_result := check_cache(*args, **kwargs)
|
cached_result := litellm.cache.check_cache(*args, **kwargs)
|
||||||
) is not None:
|
) is not None:
|
||||||
result = cached_result
|
result = cached_result
|
||||||
return result
|
return result
|
||||||
# MODEL CALL
|
# MODEL CALL
|
||||||
result = original_function(*args, **kwargs)
|
result = original_function(*args, **kwargs)
|
||||||
if "stream" in kwargs and kwargs["stream"] == True:
|
if "stream" in kwargs and kwargs["stream"] == True:
|
||||||
|
# TODO: Add to cache for streaming
|
||||||
return result
|
return result
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
# [OPTIONAL] ADD TO CACHE
|
# [OPTIONAL] ADD TO CACHE
|
||||||
if litellm.caching or litellm.caching_with_models:
|
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
|
||||||
add_cache(result, *args, **kwargs)
|
litellm.cache.add_cache(result, *args, **kwargs)
|
||||||
# LOG SUCCESS
|
# LOG SUCCESS
|
||||||
my_thread = threading.Thread(
|
my_thread = threading.Thread(
|
||||||
target=handle_success, args=(args, kwargs, result, start_time, end_time)
|
target=handle_success, args=(args, kwargs, result, start_time, end_time)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue