diff --git a/litellm/__init__.py b/litellm/__init__.py
index f607998e1f..688cd084fd 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -20,6 +20,7 @@ vertex_location: Optional[str] = None
 hugging_api_token: Optional[str] = None
 togetherai_api_key: Optional[str] = None
 caching = False
+caching_with_models = False  # if you want the caching key to be model + prompt
 model_cost = {
     "gpt-3.5-turbo": {
         "max_tokens": 4000,
diff --git a/litellm/tests/test_bad_params.py b/litellm/tests/test_bad_params.py
index 92432307be..bf05de8bd9 100644
--- a/litellm/tests/test_bad_params.py
+++ b/litellm/tests/test_bad_params.py
@@ -50,4 +50,4 @@ try:
 except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
-os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5
+os.environ["OPENAI_API_KEY"] = str(temp_key)  # this passes linting#5
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 5d7e962cf5..87ae02b58f 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -12,13 +12,13 @@ import pytest
 import litellm
 from litellm import embedding, completion
 
-litellm.caching = True
 messages = [{"role": "user", "content": "who is ishaan Github? "}]
 
 
 # test if response cached
 def test_caching():
     try:
+        litellm.caching = True
         response1 = completion(model="gpt-3.5-turbo", messages=messages)
         response2 = completion(model="gpt-3.5-turbo", messages=messages)
         print(f"response1: {response1}")
@@ -32,3 +32,21 @@ def test_caching():
         litellm.caching = False
         print(f"error occurred: {traceback.format_exc()}")
         pytest.fail(f"Error occurred: {e}")
+
+
+
+def test_caching_with_models():
+    litellm.caching_with_models = True
+    response2 = completion(model="gpt-3.5-turbo", messages=messages)
+    response3 = completion(model="command-nightly", messages=messages)
+    print(f"response2: {response2}")
+    print(f"response3: {response3}")
+    litellm.caching_with_models = False
+    if response3 == response2:
+        # if models are different, it should not return cached response
+        print(f"response2: {response2}")
+        print(f"response3: {response3}")
+        pytest.fail("different models should not share a cached response")
+
+
+
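For context on how the two flags are meant to be used from the caller's side: `litellm.caching` keys the in-memory cache on the prompt alone, while the new `litellm.caching_with_models` keys it on prompt + model, so the same prompt sent to a different model should not be served a stale response. A minimal usage sketch mirroring the tests above (the `try`/`finally` reset and the assertion are illustrative, not part of this diff):

```python
import litellm
from litellm import completion

messages = [{"role": "user", "content": "who is ishaan Github? "}]

# key the cache on prompt + model instead of prompt alone
litellm.caching_with_models = True
try:
    gpt_response = completion(model="gpt-3.5-turbo", messages=messages)
    cohere_response = completion(model="command-nightly", messages=messages)
    # different models must not share a cache entry for the same prompt
    assert cohere_response != gpt_response
finally:
    # the flag is a module-level global, so reset it to avoid leaking state
    litellm.caching_with_models = False
```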
diff --git a/litellm/utils.py b/litellm/utils.py
index b79067c36c..290e64ddba 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -263,8 +263,17 @@ def client(original_function):
             if (
                 prompt != None and prompt in local_cache
             ):  # check if messages / prompt exists
-                result = local_cache[prompt]
-                return result
+                if litellm.caching_with_models:
+                    # if caching with model names is enabled, key is prompt + model name
+                    if (
+                        "model" in kwargs
+                        and kwargs["model"] in local_cache[prompt]["models"]
+                    ):
+                        cache_key = prompt + kwargs["model"]
+                        return local_cache[cache_key]
+                else:  # caching only with prompts
+                    result = local_cache[prompt]
+                    return result
             else:
                 return None
         except:
@@ -273,7 +282,15 @@ def client(original_function):
     def add_cache(result, *args, **kwargs):
         try:  # never block execution
             prompt = get_prompt(*args, **kwargs)
-            local_cache[prompt] = result
+            if litellm.caching_with_models:  # caching with model + prompt
+                if (
+                    "model" in kwargs
+                    and kwargs["model"] in local_cache[prompt]["models"]
+                ):
+                    cache_key = prompt + kwargs["model"]
+                    local_cache[cache_key] = result
+            else:  # caching based only on prompts
+                local_cache[prompt] = result
         except:
             pass
 
@@ -284,10 +301,9 @@ def client(original_function):
             function_setup(*args, **kwargs)
             ## MODEL CALL
             start_time = datetime.datetime.now()
-            if (
-                litellm.caching
-                and (cached_result := check_cache(*args, **kwargs)) is not None
-            ):
+            if (litellm.caching or litellm.caching_with_models) and (
+                cached_result := check_cache(*args, **kwargs)
+            ) is not None:
                 result = cached_result
             else:
                 result = original_function(*args, **kwargs)
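Internally both modes share one in-memory dict (`local_cache`). With `caching_with_models`, `check_cache` first consults a per-prompt registry of model names (`local_cache[prompt]["models"]`) and then reads the composite key `prompt + model`; `add_cache` writes under the same composite key. Note that in this diff nothing populates the `"models"` registry yet, so the model-keyed lookup can only hit once that bookkeeping is added. A self-contained sketch of the intended keying, assuming a plain dict cache and simple string keys (the names `cache_store` and `cache_lookup` are illustrative, not litellm APIs):

```python
from typing import Any, Dict, Optional

local_cache: Dict[str, Any] = {}


def cache_store(prompt: str, model: str, result: Any) -> None:
    """Store a result under prompt + model and register the model for this prompt."""
    local_cache[prompt + model] = result
    registry = local_cache.setdefault(prompt, {"models": []})
    if model not in registry["models"]:
        registry["models"].append(model)


def cache_lookup(prompt: str, model: str) -> Optional[Any]:
    """Return a cached result only if this exact prompt + model pair was stored."""
    entry = local_cache.get(prompt)
    if isinstance(entry, dict) and model in entry.get("models", []):
        return local_cache.get(prompt + model)
    return None  # cache miss: no entry for this model under this prompt


# usage: a gpt-3.5-turbo entry does not satisfy a command-nightly lookup
cache_store("who is ishaan Github? ", "gpt-3.5-turbo", {"id": "cached-response"})
assert cache_lookup("who is ishaan Github? ", "command-nightly") is None
assert cache_lookup("who is ishaan Github? ", "gpt-3.5-turbo") == {"id": "cached-response"}
```

Concatenating prompt and model into a single string key is the simplest scheme; a tuple key such as `(model, prompt)` would avoid accidental collisions between different prompt/model pairs that happen to concatenate to the same string.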