test(test_caching.py): reset cache values at the end of test

Krrish Dholakia 2023-12-11 18:10:46 -08:00
parent 634d301cae
commit 6cb4ef5659
3 changed files with 47 additions and 13 deletions
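The cleanup pattern this commit repeats at the end of each test is resetting litellm's module-level cache and callback state. As a minimal sketch (not part of this commit), the same teardown could be centralized in a pytest autouse fixture; the fixture name below is hypothetical, and only the three litellm attributes it resets come from the diff that follows.

import litellm
import pytest

# Hypothetical fixture: runs teardown after every test in the module,
# mirroring the manual resets added at the end of each test in this commit.
@pytest.fixture(autouse=True)
def reset_litellm_cache_state():
    yield
    litellm.cache = None                  # disable any configured cache
    litellm.success_callback = []         # clear sync success callbacks
    litellm._async_success_callback = []  # clear async success callbacks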

.gitignore

@@ -20,3 +20,5 @@ litellm/tests/aiologs.log
 litellm/tests/exception_data.txt
 litellm/tests/config_*.yaml
 litellm/tests/langfuse.log
+litellm/tests/test_custom_logger.py
+litellm/tests/langfuse.log

litellm/tests/test_caching.py

@@ -35,6 +35,8 @@ def test_caching_v2(): # test in memory cache
     print(f"response1: {response1}")
     print(f"response2: {response2}")
     litellm.cache = None # disable cache
+    litellm.success_callback = []
+    litellm._async_success_callback = []
     if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
         print(f"response1: {response1}")
         print(f"response2: {response2}")
@@ -58,6 +60,8 @@ def test_caching_with_models_v2():
     print(f"response2: {response2}")
     print(f"response3: {response3}")
     litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
     if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
         # if models are different, it should not return cached response
         print(f"response2: {response2}")
@@ -91,6 +95,8 @@ def test_embedding_caching():
     print(f"Embedding 2 response time: {end_time - start_time} seconds")
     litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
     assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
     if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
         print(f"embedding1: {embedding1}")
@@ -145,6 +151,8 @@ def test_embedding_caching_azure():
     print(f"Embedding 2 response time: {end_time - start_time} seconds")
     litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
     assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
     if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
         print(f"embedding1: {embedding1}")
@@ -175,6 +183,8 @@ def test_redis_cache_completion():
     print("\nresponse 3", response3)
     print("\nresponse 4", response4)
     litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
     """
     1 & 2 should be exactly the same
@@ -226,6 +236,8 @@ def test_redis_cache_completion_stream():
         assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
         litellm.success_callback = []
        litellm.cache = None
+        litellm.success_callback = []
+        litellm._async_success_callback = []
     except Exception as e:
         print(e)
         litellm.success_callback = []
@@ -271,10 +283,12 @@ def test_redis_cache_acompletion_stream():
         print("\nresponse 2", response_2_content)
         assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
         litellm.cache = None
+        litellm.success_callback = []
+        litellm._async_success_callback = []
     except Exception as e:
         print(e)
         raise e
-test_redis_cache_acompletion_stream()
+# test_redis_cache_acompletion_stream()

 def test_redis_cache_acompletion_stream_bedrock():
     import asyncio
@@ -310,6 +324,8 @@ def test_redis_cache_acompletion_stream_bedrock():
         print("\nresponse 2", response_2_content)
         assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
         litellm.cache = None
+        litellm.success_callback = []
+        litellm._async_success_callback = []
     except Exception as e:
         print(e)
         raise e
@@ -350,6 +366,8 @@ def test_custom_redis_cache_with_key():
     if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
         pytest.fail(f"Error occurred:")
     litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []

 # test_custom_redis_cache_with_key()
@@ -371,6 +389,8 @@ def test_custom_redis_cache_params():
         print(litellm.cache.cache.redis_client)
         litellm.cache = None
+        litellm.success_callback = []
+        litellm._async_success_callback = []
     except Exception as e:
         pytest.fail(f"Error occurred:", e)

litellm/tests/test_custom_logger.py

@@ -8,8 +8,8 @@ import litellm
 from litellm.integrations.custom_logger import CustomLogger

 async_success = False
-complete_streaming_response_in_callback = ""
 class MyCustomHandler(CustomLogger):
+    complete_streaming_response_in_callback = ""
     def __init__(self):
         self.success: bool = False # type: ignore
         self.failure: bool = False # type: ignore
@@ -72,19 +72,20 @@ class MyCustomHandler(CustomLogger):
         self.async_completion_kwargs_fail = kwargs

-async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
-    global async_success, complete_streaming_response_in_callback
-    print(f"ON ASYNC LOGGING")
-    async_success = True
-    print("\nKWARGS", kwargs)
-    complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")
+class TmpFunction:
+    complete_streaming_response_in_callback = ""
+    async_success: bool = False
+    async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
+        print(f"ON ASYNC LOGGING")
+        self.async_success = True
+        self.complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")

 def test_async_chat_openai_stream():
     try:
-        global complete_streaming_response_in_callback
+        tmp_function = TmpFunction()
         # litellm.set_verbose = True
-        litellm.success_callback = [async_test_logging_fn]
+        litellm.success_callback = [tmp_function.async_test_logging_fn]
         complete_streaming_response = ""
         async def call_gpt():
             nonlocal complete_streaming_response
@@ -98,12 +99,23 @@ def test_async_chat_openai_stream():
                 complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
             print(complete_streaming_response)
         asyncio.run(call_gpt())
-        assert complete_streaming_response_in_callback["choices"][0]["message"]["content"] == complete_streaming_response
-        assert async_success == True
+        complete_streaming_response = complete_streaming_response.strip("'")
+        print(f"complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content']}")
+        print(f"type of complete_streaming_response_in_callback: {type(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
+        print(f"hidden char complete_streaming_response_in_callback: {repr(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
+        print(f"encoding complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'].encode('utf-8')}")
+        print(f"complete_streaming_response: {complete_streaming_response}")
+        print(f"type(complete_streaming_response): {type(complete_streaming_response)}")
+        print(f"hidden char complete_streaming_response): {repr(complete_streaming_response)}")
+        print(f"encoding complete_streaming_response): {repr(complete_streaming_response).encode('utf-8')}")
+        response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
+        response2 = complete_streaming_response
+        assert [ord(c) for c in response1] == [ord(c) for c in response2]
+        assert tmp_function.async_success == True
     except Exception as e:
         print(e)
         pytest.fail(f"An error occurred - {str(e)}")
-test_async_chat_openai_stream()
+# test_async_chat_openai_stream()

 def test_completion_azure_stream_moderation_failure():
     try:
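The rewritten assertion in test_async_chat_openai_stream compares the two responses code point by code point instead of as whole strings. A small standalone illustration of why an ord() comparison surfaces differences that are invisible when the strings are printed; the example strings below are made up, not taken from the test.

# Two strings that render identically but differ by one hidden character.
s1 = "Hello world"
s2 = "Hello\u00a0world"  # non-breaking space instead of a regular space
print(s1 == s2)  # False, yet both look the same on screen
# Comparing code points pinpoints the position and values of the mismatch.
print([(i, ord(a), ord(b)) for i, (a, b) in enumerate(zip(s1, s2)) if a != b])  # [(5, 32, 160)]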