test(test_caching.py): add disk cache test when using completion

This commit is contained in:
Antonio Loison 2024-04-24 15:59:22 +02:00
parent 34f4f719aa
commit bca84d46b1

View file

@ -875,6 +875,66 @@ async def test_redis_cache_acompletion_stream_bedrock():
print(e)
raise e
def test_disk_cache_completion():
litellm.set_verbose = False
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="disk",
)
print("test2 for Redis Caching - non streaming")
response1 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response2 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response3 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5
)
print("\nresponse 1", response1)
print("\nresponse 2", response2)
print("\nresponse 3", response3)
# print("\nresponse 4", response4)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
"""
1 & 2 should be exactly the same
1 & 3 should be different, since input params are diff
1 & 4 should be diff, since models are diff
"""
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
): # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
if (
response1["choices"][0]["message"]["content"]
== response3["choices"][0]["message"]["content"]
):
# if input params like seed, max_tokens are diff it should NOT be a cache hit
print(f"response1: {response1}")
print(f"response3: {response3}")
pytest.fail(
f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
f" occurred:"
)
assert response1.id == response2.id
assert response1.created == response2.created
assert response1.choices[0].message.content == response2.choices[0].message.content
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio