diff --git a/litellm/__init__.py b/litellm/__init__.py
index 227a7d662..363cb6205 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -1,5 +1,6 @@
 import threading
 from typing import Callable, List, Optional, Dict
+from litellm.caching import Cache
 
 input_callback: List[str] = []
 success_callback: List[str] = []
@@ -30,6 +31,7 @@ baseten_key: Optional[str] = None
 use_client = False
 logging = True
 caching = False
+cache: Optional[Cache] = None # set to litellm.caching Cache() object
 caching_with_models = False # if you want the caching key to be model + prompt
 model_alias_map: Dict[str, str] = {}
 model_cost = {
diff --git a/litellm/caching.py b/litellm/caching.py
new file mode 100644
index 000000000..c7ed6a258
--- /dev/null
+++ b/litellm/caching.py
@@ -0,0 +1,81 @@
+import redis
+import litellm, openai
+
+def get_prompt(*args, **kwargs):
+    # make this safe checks, it should not throw any exceptions
+    if len(args) > 1:
+        messages = args[1]
+        prompt = " ".join(message["content"] for message in messages)
+        return prompt
+    if "messages" in kwargs:
+        messages = kwargs["messages"]
+        prompt = " ".join(message["content"] for message in messages)
+        return prompt
+    return None
+
+class RedisCache():
+    import redis
+    def __init__(self, host, port, password):
+        # if users don't provider one, use the default litellm cache
+        self.redis_client = redis.Redis(host=host, port=port, password=password)
+
+    def set_cache(self, key, value):
+        self.redis_client.set(key, str(value))
+
+    def get_cache(self, key):
+        # TODO convert this to a ModelResponse object
+        return self.redis_client.get(key)
+
+class InMemoryCache():
+    def __init__(self):
+        # if users don't provider one, use the default litellm cache
+        self.cache_dict = {}
+
+    def set_cache(self, key, value):
+        self.cache_dict[key] = value
+
+    def get_cache(self, key):
+        if key in self.cache_dict:
+            return self.cache_dict[key]
+        return None
+
+class Cache():
+    def __init__(self, type="local", host="", port="", password=""):
+        if type == "redis":
+            self.cache = RedisCache(type, host, port, password)
+        if type == "local":
+            self.cache = InMemoryCache()
+
+    def check_cache(self, *args, **kwargs):
+        try: # never block execution
+            prompt = get_prompt(*args, **kwargs)
+            if prompt != None: # check if messages / prompt exists
+                if "model" in kwargs: # default to caching with `model + prompt` as key
+                    cache_key = prompt + kwargs["model"]
+                    return self.cache.get_cache(cache_key)
+                else:
+                    return self.cache.get_cache(prompt)
+        except:
+            return None
+
+    def add_cache(self, result, *args, **kwargs):
+        try:
+            prompt = get_prompt(*args, **kwargs)
+            if "model" in kwargs: # default to caching with `model + prompt` as key
+                cache_key = prompt + kwargs["model"]
+                self.cache.set_cache(cache_key, result)
+            else:
+                self.cache.set_cache(prompt, result)
+        except:
+            pass
+
+
+
+
+
+
+
+
+
+
+
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index ef9b24a43..fdd9fe798 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -11,6 +11,7 @@ sys.path.insert(
 import pytest
 import litellm
 from litellm import embedding, completion
+from litellm.caching import Cache
 
 messages = [{"role": "user", "content": "who is ishaan Github? "}]
 
@@ -78,3 +79,50 @@ def test_gpt_cache():
 
 
 # test_gpt_cache()
+
+
+####### Updated Caching as of Aug 28, 2023 ###################
+messages = [{"role": "user", "content": "who is ishaan 5222"}]
+def test_caching():
+    try:
+        litellm.cache = Cache()
+        response1 = completion(model="gpt-3.5-turbo", messages=messages)
+        response2 = completion(model="gpt-3.5-turbo", messages=messages)
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        litellm.cache = None # disable cache
+        if response2 != response1:
+            print(f"response1: {response1}")
+            print(f"response2: {response2}")
+            pytest.fail(f"Error occurred: {e}")
+    except Exception as e:
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
+
+# test_caching()
+
+
+def test_caching_with_models():
+    messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
+    litellm.cache = Cache()
+    print("test2 for caching")
+    response1 = completion(model="gpt-3.5-turbo", messages=messages)
+    response2 = completion(model="gpt-3.5-turbo", messages=messages)
+    response3 = completion(model="command-nightly", messages=messages)
+    print(f"response1: {response1}")
+    print(f"response2: {response2}")
+    print(f"response3: {response3}")
+    litellm.cache = None
+    if response3 == response2:
+        # if models are different, it should not return cached response
+        print(f"response2: {response2}")
+        print(f"response3: {response3}")
+        pytest.fail(f"Error occurred:")
+    if response1 != response2:
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        pytest.fail(f"Error occurred:")
+
+# test_caching_with_models()
+
+
diff --git a/litellm/utils.py b/litellm/utils.py
index fb136a330..7746a82d8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -393,50 +393,6 @@ def client(original_function):
             # [Non-Blocking Error]
             pass
 
-    def get_prompt(*args, **kwargs):
-        # make this safe checks, it should not throw any exceptions
-        if len(args) > 1:
-            messages = args[1]
-            prompt = " ".join(message["content"] for message in messages)
-            return prompt
-        if "messages" in kwargs:
-            messages = kwargs["messages"]
-            prompt = " ".join(message["content"] for message in messages)
-            return prompt
-        return None
-
-    def check_cache(*args, **kwargs):
-        try: # never block execution
-            prompt = get_prompt(*args, **kwargs)
-            if prompt != None: # check if messages / prompt exists
-                if litellm.caching_with_models:
-                    # if caching with model names is enabled, key is prompt + model name
-                    if "model" in kwargs:
-                        cache_key = prompt + kwargs["model"]
-                        if cache_key in local_cache:
-                            return local_cache[cache_key]
-                else: # caching only with prompts
-                    if prompt in local_cache:
-                        result = local_cache[prompt]
-                        return result
-            else:
-                return None
-            return None # default to return None
-        except:
-            return None
-
-    def add_cache(result, *args, **kwargs):
-        try: # never block execution
-            prompt = get_prompt(*args, **kwargs)
-            if litellm.caching_with_models: # caching with model + prompt
-                if "model" in kwargs:
-                    cache_key = prompt + kwargs["model"]
-                    local_cache[cache_key] = result
-            else: # caching based only on prompts
-                local_cache[prompt] = result
-        except:
-            pass
-
     def wrapper(*args, **kwargs):
         start_time = None
         result = None
@@ -446,19 +402,20 @@ def client(original_function):
             kwargs["litellm_call_id"] = litellm_call_id
             start_time = datetime.datetime.now()
             # [OPTIONAL] CHECK CACHE
-            if (litellm.caching or litellm.caching_with_models) and (
-                cached_result := check_cache(*args, **kwargs)
+            if (litellm.caching or litellm.caching_with_models or litellm.cache != None) and (
+                cached_result := litellm.cache.check_cache(*args, **kwargs)
             ) is not None:
                 result = cached_result
                 return result
             # MODEL CALL
             result = original_function(*args, **kwargs)
             if "stream" in kwargs and kwargs["stream"] == True:
+                # TODO: Add to cache for streaming
                 return result
             end_time = datetime.datetime.now()
             # [OPTIONAL] ADD TO CACHE
-            if litellm.caching or litellm.caching_with_models:
-                add_cache(result, *args, **kwargs)
+            if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
+                litellm.cache.add_cache(result, *args, **kwargs)
             # LOG SUCCESS
             my_thread = threading.Thread(
                 target=handle_success, args=(args, kwargs, result, start_time, end_time)
@@ -1730,4 +1687,4 @@ def completion_with_fallbacks(**kwargs):
                 ) # cool down this selected model
                 # print(f"rate_limited_models {rate_limited_models}")
             pass
-    return response
+    return response
\ No newline at end of file
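
Usage sketch (not part of the patch): assuming the Cache class and the wrapper changes above land as shown, application code would enable the new cache roughly like this. The prompt text and the Redis host/port/password values are placeholders, not values taken from the diff.

import litellm
from litellm import completion
from litellm.caching import Cache

# In-memory cache (type="local" is the default). The client wrapper in utils.py
# checks litellm.cache before the model call, so a repeat call with the same
# model + prompt should be answered from the cache.
litellm.cache = Cache()

messages = [{"role": "user", "content": "Hello, how are you?"}]
response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages)  # expected cache hit

# Redis-backed variant, per the Cache(type, host, port, password) constructor above;
# the connection details here are placeholders.
# litellm.cache = Cache(type="redis", host="localhost", port="6379", password="")

litellm.cache = None  # turn caching back off, as the tests do

Because check_cache and add_cache key on prompt + model whenever model is passed, the command-nightly call in test_caching_with_models is expected to miss the entries written for gpt-3.5-turbo, which is what that test asserts.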