fixes to caching, testing caching

ishaan-jaff 2023-08-17 10:41:14 -07:00
parent cff26b1d08
commit 8b245278b2
4 changed files with 57 additions and 23 deletions


@@ -9,20 +9,27 @@ import litellm
 from litellm import embedding, completion
 litellm.caching = True
-messages = [{"role": "user", "content": "Hey, how's it going?"}]
+messages = [{"role": "user", "content": "who is ishaan Github? "}]
 # test if response cached
-try:
-    response1 = completion(model="gpt-3.5-turbo", messages=messages)
-    response2 = completion(model="gpt-3.5-turbo", messages=messages)
-    if response2 != response1:
+def test_caching():
+    try:
+        response1 = completion(model="gpt-3.5-turbo", messages=messages)
+        response2 = completion(model="gpt-3.5-turbo", messages=messages)
         print(f"response1: {response1}")
         print(f"response2: {response2}")
-        raise Exception
-except Exception as e:
-    print(f"error occurred: {traceback.format_exc()}")
-    pytest.fail(f"Error occurred: {e}")
-litellm.caching = False
+        litellm.caching = False
+        if response2 != response1:
+            print(f"response1: {response1}")
+            print(f"response2: {response2}")
+            pytest.fail("Error occurred: responses are not equal, caching did not work")
+    except Exception as e:
+        litellm.caching = False
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
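For reference, what the new test_caching exercises: with the module-level litellm.caching flag enabled, a second completion() call with an identical messages list is answered from the in-process cache, so the two responses compare equal. A minimal sketch of that round trip (the assert is mine, not part of the commit):

import litellm
from litellm import completion

litellm.caching = True
messages = [{"role": "user", "content": "who is ishaan Github? "}]

first = completion(model="gpt-3.5-turbo", messages=messages)   # real API call; result gets cached
second = completion(model="gpt-3.5-turbo", messages=messages)  # same prompt; served from the cache

assert second == first  # a cache hit returns the stored response

litellm.caching = False  # reset the global flag so later calls are unaffected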


@@ -30,7 +30,7 @@ def test_completion_openai():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
-test_completion_openai()
+# test_completion_openai()

 def test_completion_claude():
     try:
@@ -38,14 +38,14 @@ def test_completion_claude():
         # Add any assertions here to check the response
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_claude()
+# test_completion_claude()

 def test_completion_non_openai():
     try:
         response = completion(model="command-nightly", messages=messages, logger_fn=logger_fn)
         # Add any assertions here to check the response
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_non_openai()
+# test_completion_non_openai()

 def test_embedding_openai():
     try:
         response = embedding(model='text-embedding-ada-002', input=[user_message], logger_fn=logger_fn)


@@ -8,8 +8,8 @@ import litellm
 from litellm import embedding, completion
 from infisical import InfisicalClient
-# litellm.set_verbose = True
-litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
+# # litellm.set_verbose = True
+# litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])

 def test_openai_embedding():
     try:
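With the Infisical client commented out, key lookups in this test fall back to plain environment variables. A hypothetical sketch of that fallback; the helper name resolve_secret and the get_secret(...).secret_value access are my assumptions about the Infisical SDK, not code from this commit:

import os
import litellm

def resolve_secret(name):
    # prefer a configured secret manager; otherwise read the environment
    client = getattr(litellm, "secret_manager_client", None)
    if client is not None:
        return client.get_secret(name).secret_value  # assumed Infisical SDK shape
    return os.environ.get(name)

# with the client commented out, resolve_secret("OPENAI_API_KEY") reads os.environ directly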


@@ -137,6 +137,36 @@ def client(original_function):
             #[Non-Blocking Error]
             pass

+    def get_prompt(*args, **kwargs):
+        # safe checks only; this helper should not throw any exceptions
+        if len(args) > 1:
+            messages = args[1]
+            prompt = " ".join(message["content"] for message in messages)
+            return prompt
+        if "messages" in kwargs:
+            messages = kwargs["messages"]
+            prompt = " ".join(message["content"] for message in messages)
+            return prompt
+        return None
+
+    def check_cache(*args, **kwargs):
+        try:  # never block execution
+            prompt = get_prompt(*args, **kwargs)
+            if prompt is not None and prompt in local_cache:  # check if this prompt was seen before
+                result = local_cache[prompt]
+                return result
+            else:
+                return None
+        except Exception:
+            return None
+
+    def add_cache(result, *args, **kwargs):
+        try:  # never block execution
+            prompt = get_prompt(*args, **kwargs)
+            local_cache[prompt] = result
+        except Exception:
+            pass
+
     def wrapper(*args, **kwargs):
         start_time = None
         result = None
@@ -144,17 +174,14 @@ def client(original_function):
             function_setup(*args, **kwargs)
             ## MODEL CALL
             start_time = datetime.datetime.now()
-            ## CHECK CACHE RESPONSES
-            messages = args[1] if len(args) > 1 else kwargs["messages"]
-            prompt = " ".join(message["content"] for message in messages)
-            if litellm.caching and prompt in local_cache:
-                result = local_cache[prompt]
+            if litellm.caching and (cached_result := check_cache(*args, **kwargs)) is not None:
+                result = cached_result
             else:
                 result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
-            ## CACHE RESPONSES
+            ## Add response to CACHE
             if litellm.caching:
-                local_cache[prompt] = result
+                add_cache(result, *args, **kwargs)
             ## LOG SUCCESS
             crash_reporting(*args, **kwargs)
             my_thread = threading.Thread(target=handle_success, args=(args, kwargs, result, start_time, end_time))  # don't interrupt execution of main thread
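Taken together, get_prompt, check_cache, and add_cache implement a prompt-keyed, in-memory cache around the model call: the key is just the concatenated content of the messages, and every cache operation is wrapped so a failure can never block the request itself. A self-contained sketch of the same pattern outside the decorator; the driver cached_call is hypothetical, and local_cache here is a plain module-level dict mirroring the one the decorator closes over:

local_cache = {}

def get_prompt(*args, **kwargs):
    # build the cache key from the messages; return None if there are none
    messages = args[1] if len(args) > 1 else kwargs.get("messages")
    if not messages:
        return None
    return " ".join(message["content"] for message in messages)

def cached_call(fn, *args, **kwargs):
    # hypothetical driver showing the check_cache / add_cache flow
    prompt = get_prompt(*args, **kwargs)
    if prompt is not None and prompt in local_cache:
        return local_cache[prompt]      # cache hit: skip the model call
    result = fn(*args, **kwargs)        # cache miss: call the model
    if prompt is not None:
        local_cache[prompt] = result    # remember it for the next identical prompt
    return result

Note that the key ignores everything except the message contents: two calls with the same messages but different models (or other kwargs) would collide on the same cache entry, a limitation worth keeping in mind with this first cut.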