fixes to caching, testing caching

ishaan-jaff 2023-08-17 10:41:14 -07:00
parent cff26b1d08
commit 8b245278b2
4 changed files with 57 additions and 23 deletions

View file

@@ -9,20 +9,27 @@ import litellm
 from litellm import embedding, completion
 
 litellm.caching = True
-messages = [{"role": "user", "content": "Hey, how's it going?"}]
+messages = [{"role": "user", "content": "who is ishaan Github? "}]
 
 # test if response cached
-try:
-    response1 = completion(model="gpt-3.5-turbo", messages=messages)
-    response2 = completion(model="gpt-3.5-turbo", messages=messages)
-    if response2 != response1:
-        print(f"response1: {response1}")
-        print(f"response2: {response2}")
-        raise Exception
-except Exception as e:
-    print(f"error occurred: {traceback.format_exc()}")
-    pytest.fail(f"Error occurred: {e}")
-litellm.caching = False
+def test_caching():
+    try:
+        response1 = completion(model="gpt-3.5-turbo", messages=messages)
+        response2 = completion(model="gpt-3.5-turbo", messages=messages)
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        litellm.caching = False
+        if response2 != response1:
+            print(f"response1: {response1}")
+            print(f"response2: {response2}")
+            pytest.fail(f"Error occurred: {e}")
+    except Exception as e:
+        litellm.caching = False
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")

View file

@@ -30,7 +30,7 @@ def test_completion_openai():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
+test_completion_openai()
 
 def test_completion_claude():
     try:
@@ -38,14 +38,14 @@ def test_completion_claude():
         # Add any assertions here to check the response
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+test_completion_claude()
 
 def test_completion_non_openai():
     try:
         response = completion(model="command-nightly", messages=messages, logger_fn=logger_fn)
         # Add any assertions here to check the response
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+test_completion_non_openai()
 
 def test_embedding_openai():
     try:
         response = embedding(model='text-embedding-ada-002', input=[user_message], logger_fn=logger_fn)
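
If the bare test calls above are meant for ad-hoc `python` runs, a hedged alternative is to keep them behind a main guard so they do not execute on import or during pytest collection (the function names are taken from the file above; whether direct runs are intended is an assumption):

if __name__ == "__main__":
    test_completion_openai()
    test_completion_claude()
    test_completion_non_openai()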

View file

@@ -8,8 +8,8 @@ import litellm
 from litellm import embedding, completion
 from infisical import InfisicalClient
 
-# litellm.set_verbose = True
-litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
+# # litellm.set_verbose = True
+# litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
 
 def test_openai_embedding():
     try:
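
Rather than commenting the secret-manager setup out entirely, a hedged alternative is to attach it only when the token is present, so the same file still works where INFISICAL_TOKEN is configured. The motivation for disabling it is an assumption; the InfisicalClient call is the one shown above.

import os
import litellm
from infisical import InfisicalClient

# only wire up Infisical when a token is actually available
if os.environ.get("INFISICAL_TOKEN"):
    litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])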

View file

@@ -137,6 +137,36 @@ def client(original_function):
             #[Non-Blocking Error]
             pass
 
+    def get_prompt(*args, **kwargs):
+        # make this safe checks, it should not throw any exceptions
+        if len(args) > 1:
+            messages = args[1]
+            prompt = " ".join(message["content"] for message in messages)
+            return prompt
+        if "messages" in kwargs:
+            messages = kwargs["messages"]
+            prompt = " ".join(message["content"] for message in messages)
+            return prompt
+        return None
+
+    def check_cache(*args, **kwargs):
+        try:  # never block execution
+            prompt = get_prompt(*args, **kwargs)
+            if prompt != None and prompt in local_cache:  # check if messages / prompt exists
+                result = local_cache[prompt]
+                return result
+            else:
+                return None
+        except:
+            return None
+
+    def add_cache(result, *args, **kwargs):
+        try:  # never block execution
+            prompt = get_prompt(*args, **kwargs)
+            local_cache[prompt] = result
+        except:
+            pass
+
     def wrapper(*args, **kwargs):
         start_time = None
         result = None
@@ -144,17 +174,14 @@ def client(original_function):
             function_setup(*args, **kwargs)
             ## MODEL CALL
             start_time = datetime.datetime.now()
-            ## CHECK CACHE RESPONSES
-            messages = args[1] if len(args) > 1 else kwargs["messages"]
-            prompt = " ".join(message["content"] for message in messages)
-            if litellm.caching and prompt in local_cache:
-                result = local_cache[prompt]
+            if litellm.caching and (cached_result := check_cache(*args, **kwargs)) is not None:
+                result = cached_result
             else:
                 result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
-            ## CACHE RESPONSES
+            ## Add response to CACHE
            if litellm.caching:
-                local_cache[prompt] = result
+                add_cache(result, *args, **kwargs)
             ## LOG SUCCESS
             crash_reporting(*args, **kwargs)
             my_thread = threading.Thread(target=handle_success, args=(args, kwargs, result, start_time, end_time)) # don't interrupt execution of main thread
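
Taken together, the two hunks turn client() into a decorator that consults a prompt-keyed dict before the model call and fills it afterwards. Below is a self-contained sketch of that flow with a stub in place of the real completion call; local_cache, check_cache, add_cache, and the litellm.caching flag mirror the names in the diff, everything else is illustrative.

local_cache = {}
caching = True  # stands in for litellm.caching

def get_prompt(*args, **kwargs):
    # best-effort extraction of the prompt text
    messages = args[1] if len(args) > 1 else kwargs.get("messages")
    if messages is None:
        return None
    return " ".join(message["content"] for message in messages)

def check_cache(*args, **kwargs):
    try:  # never block execution
        prompt = get_prompt(*args, **kwargs)
        if prompt is not None:
            return local_cache.get(prompt)
    except Exception:
        pass
    return None

def add_cache(result, *args, **kwargs):
    try:  # never block execution
        prompt = get_prompt(*args, **kwargs)
        if prompt is not None:
            local_cache[prompt] = result
    except Exception:
        pass

def client(original_function):
    def wrapper(*args, **kwargs):
        # serve a cache hit without calling the model
        if caching and (cached_result := check_cache(*args, **kwargs)) is not None:
            return cached_result
        result = original_function(*args, **kwargs)
        if caching:
            add_cache(result, *args, **kwargs)
        return result
    return wrapper

@client
def fake_completion(model, messages):
    # stand-in for litellm.completion; returns a fresh object per real call
    return {"model": model, "echo": messages[-1]["content"]}

msgs = [{"role": "user", "content": "who is ishaan Github? "}]
first = fake_completion("gpt-3.5-turbo", messages=msgs)
second = fake_completion("gpt-3.5-turbo", messages=msgs)
assert first is second  # the second identical prompt is a cache hit

The walrus expression requires Python 3.8+, matching the committed wrapper.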