adding exact match caching

Krrish Dholakia 2023-08-16 21:44:50 -07:00
parent 4d475793ee
commit 79bcb59e0b
7 changed files with 41 additions and 3 deletions
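In short, this change adds a process-local, exact-match response cache: a module-level `caching` flag, a `local_cache` dict, and logic in the `client` wrapper that joins the message contents into a prompt string and reuses any previously stored response for that exact string. Below is a minimal standalone sketch of the idea, assuming the dict-based cache shown in the diff; `fake_completion` and `cached_completion` are illustrative names, not part of the commit.

```python
# Minimal sketch of the exact-match caching added in this commit.
# The cache key is the space-joined message contents, as in the `client` wrapper;
# `fake_completion` stands in for the real model call.
local_cache = {}
caching = True  # mirrors the new module-level `litellm.caching` flag

def fake_completion(model, messages):
    # placeholder for the actual LLM call
    return {"model": model, "echo": messages[-1]["content"]}

def cached_completion(model, messages):
    prompt = " ".join(message["content"] for message in messages)
    if caching and prompt in local_cache:
        return local_cache[prompt]      # exact-match cache hit
    result = fake_completion(model, messages)
    if caching:
        local_cache[prompt] = result    # store for identical future prompts
    return result

messages = [{"role": "user", "content": "Hey, how's it going?"}]
r1 = cached_completion("gpt-3.5-turbo", messages)
r2 = cached_completion("gpt-3.5-turbo", messages)
assert r1 is r2  # second call is served from the cache
```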

View file

@@ -15,7 +15,7 @@ openrouter_key = None
huggingface_key = None
vertex_project = None
vertex_location = None
caching = False
hugging_api_token = None
model_cost = {
    "gpt-3.5-turbo": {"max_tokens": 4000, "input_cost_per_token": 0.0000015, "output_cost_per_token": 0.000002},

View file

@@ -0,0 +1,27 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(0, os.path.abspath('../..'))  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion

litellm.caching = True
messages = [{"role": "user", "content": "Hey, how's it going?"}]

# test if response cached
try:
    response1 = completion(model="gpt-3.5-turbo", messages=messages)
    response2 = completion(model="gpt-3.5-turbo", messages=messages)
    if response2 != response1:
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        raise Exception
except Exception as e:
    print(f"error occurred: {traceback.format_exc()}")
    pytest.fail(f"Error occurred: {e}")

View file

@@ -213,7 +213,7 @@ def test_completion_together_ai_stream():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

test_completion_together_ai_stream()

def test_petals():
    model_name = "stabilityai/StableBeluga2"
    try:

View file

@@ -28,6 +28,7 @@ supabaseClient = None
callback_list = []
user_logger_fn = None
additional_details = {}
local_cache = {}

def print_verbose(print_statement):
    if litellm.set_verbose:
@@ -138,12 +139,22 @@ def client(original_function):
    def wrapper(*args, **kwargs):
        start_time = None
        result = None
        try:
            function_setup(*args, **kwargs)
            ## MODEL CALL
            start_time = datetime.datetime.now()
            ## CHECK CACHED RESPONSES
            messages = args[1] if len(args) > 1 else kwargs["messages"]
            prompt = " ".join(message["content"] for message in messages)
            if litellm.caching and prompt in local_cache:
                result = local_cache[prompt]
            else:
                result = original_function(*args, **kwargs)
            end_time = datetime.datetime.now()
            ## CACHE RESPONSES
            if litellm.caching:
                local_cache[prompt] = result
            ## LOG SUCCESS
            crash_reporting(*args, **kwargs)
            my_thread = threading.Thread(target=handle_success, args=(args, kwargs, result, start_time, end_time))  # don't interrupt execution of main thread
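A note on the design visible in this hunk: the cache key is only the space-joined `content` of the messages, so identical prompts sent to different models or with different parameters would share an entry, and the store is an in-process dict with no eviction or TTL. These look like deliberate simplifications for a first exact-match cache rather than a general caching layer.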