diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 6e18485f4..c6500c557 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -9,20 +9,27 @@ import litellm
 from litellm import embedding, completion
 litellm.caching = True
 
-messages = [{"role": "user", "content": "Hey, how's it going?"}]
+messages = [{"role": "user", "content": "who is ishaan Github? "}]
 
 # test if response cached
-try:
-    response1 = completion(model="gpt-3.5-turbo", messages=messages)
-    response2 = completion(model="gpt-3.5-turbo", messages=messages)
-    if response2 != response1:
+def test_caching():
+    try:
+        response1 = completion(model="gpt-3.5-turbo", messages=messages)
+        response2 = completion(model="gpt-3.5-turbo", messages=messages)
         print(f"response1: {response1}")
         print(f"response2: {response2}")
-        raise Exception
-except Exception as e:
-    print(f"error occurred: {traceback.format_exc()}")
-    pytest.fail(f"Error occurred: {e}")
-litellm.caching = False
+        litellm.caching = False
+        if response2 != response1:
+            print(f"response1: {response1}")
+            print(f"response2: {response2}")
+            pytest.fail(f"Error occurred: {e}")
+    except Exception as e:
+        litellm.caching = False
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
+
+
+
diff --git a/litellm/tests/test_client.py b/litellm/tests/test_client.py
index b329e4c65..3c591d4cd 100644
--- a/litellm/tests/test_client.py
+++ b/litellm/tests/test_client.py
@@ -30,7 +30,7 @@ def test_completion_openai():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
-test_completion_openai()
+
 
 def test_completion_claude():
     try:
@@ -38,14 +38,14 @@ def test_completion_claude():
         # Add any assertions here to check the response
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_claude()
+
 
 def test_completion_non_openai():
     try:
         response = completion(model="command-nightly", messages=messages, logger_fn=logger_fn)
         # Add any assertions here to check the response
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_non_openai()
+
 
 def test_embedding_openai():
     try:
         response = embedding(model='text-embedding-ada-002', input=[user_message], logger_fn=logger_fn)
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index ce83ffc70..a31d2a4fa 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -8,8 +8,8 @@ import litellm
 from litellm import embedding, completion
 from infisical import InfisicalClient
 
-# litellm.set_verbose = True
-litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
+# # litellm.set_verbose = True
+# litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
 
 def test_openai_embedding():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index e5f803e31..86f07add3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -137,6 +137,36 @@ def client(original_function):
             #[Non-Blocking Error]
             pass
 
+    def get_prompt(*args, **kwargs):
+        # make this safe checks, it should not throw any exceptions
+        if len(args) > 1:
+            messages = args[1]
+            prompt = " ".join(message["content"] for message in messages)
+            return prompt
+        if "messages" in kwargs:
+            messages = kwargs["messages"]
+            prompt = " ".join(message["content"] for message in messages)
+            return prompt
+        return None
+
+    def check_cache(*args, **kwargs):
+        try: # never block execution
+            prompt = get_prompt(*args, **kwargs)
+            if prompt != None and prompt in local_cache: # check if messages / prompt exists
+                result = local_cache[prompt]
+                return result
+            else:
+                return None
+        except:
+            return None
+
+    def add_cache(result, *args, **kwargs):
+        try: # never block execution
+            prompt = get_prompt(*args, **kwargs)
+            local_cache[prompt] = result
+        except:
+            pass
+
     def wrapper(*args, **kwargs):
         start_time = None
         result = None
@@ -144,17 +174,14 @@ def client(original_function):
             function_setup(*args, **kwargs)
             ## MODEL CALL
             start_time = datetime.datetime.now()
-            ## CHECK CACHE RESPONSES
-            messages = args[1] if len(args) > 1 else kwargs["messages"]
-            prompt = " ".join(message["content"] for message in messages)
-            if litellm.caching and prompt in local_cache:
-                result = local_cache[prompt]
+            if litellm.caching and (cached_result := check_cache(*args, **kwargs)) is not None:
+                result = cached_result
             else:
-                result = original_function(*args, **kwargs)
+                result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
-            ## CACHE RESPONSES
+            ## Add response to CACHE
             if litellm.caching:
-                local_cache[prompt] = result
+                add_cache(result, *args, **kwargs)
             ## LOG SUCCESS
             crash_reporting(*args, **kwargs)
             my_thread = threading.Thread(target=handle_success, args=(args, kwargs, result, start_time, end_time)) # don't interrupt execution of main thread
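Note on the litellm/utils.py change above: the new helpers reduce to a prompt-keyed dictionary lookup wrapped around the model call. The following is a minimal, self-contained sketch of that same pattern, not the litellm implementation itself; `fake_completion` and the module-level `caching` flag are illustrative stand-ins (in litellm the flag is `litellm.caching` and the wrapped call is the real `completion`), while get_prompt / check_cache / add_cache mirror, in slightly simplified form, the helpers introduced in the diff.

# Sketch of the prompt-keyed caching pattern added in litellm/utils.py.
# `fake_completion` and the bare `caching` flag are hypothetical stand-ins.
local_cache = {}
caching = True

def get_prompt(*args, **kwargs):
    # Best-effort prompt extraction from positional or keyword args; returns None rather than raising.
    if len(args) > 1:
        return " ".join(m["content"] for m in args[1])
    if "messages" in kwargs:
        return " ".join(m["content"] for m in kwargs["messages"])
    return None

def check_cache(*args, **kwargs):
    try:  # never block execution
        prompt = get_prompt(*args, **kwargs)
        return local_cache.get(prompt) if prompt is not None else None
    except Exception:
        return None

def add_cache(result, *args, **kwargs):
    try:  # never block execution
        prompt = get_prompt(*args, **kwargs)
        if prompt is not None:
            local_cache[prompt] = result
    except Exception:
        pass

def client(original_function):
    def wrapper(*args, **kwargs):
        if caching and (cached := check_cache(*args, **kwargs)) is not None:
            return cached  # cache hit: skip the model call
        result = original_function(*args, **kwargs)
        if caching:
            add_cache(result, *args, **kwargs)
        return result
    return wrapper

@client
def fake_completion(model, messages):
    # Stand-in for the real completion call.
    return {"model": model, "echo": messages[-1]["content"]}

messages = [{"role": "user", "content": "who is ishaan Github? "}]
response1 = fake_completion("gpt-3.5-turbo", messages)
response2 = fake_completion("gpt-3.5-turbo", messages)  # identical prompt, served from local_cache
assert response1 == response2  # the equality that test_caching checks for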