fix(caching.py): fix async in-memory caching

This commit is contained in:
Krrish Dholakia 2024-01-13 15:33:57 +05:30
parent cdadac1649
commit 79cc739b53
3 changed files with 16 additions and 42 deletions

View file

@@ -279,55 +279,20 @@ async def test_embedding_caching_azure_individual_items():
litellm.cache = Cache()
common_msg = f"hey how's it going {uuid.uuid4()}"
common_msg_2 = f"hey how's it going {uuid.uuid4()}"
embedding_1 = [common_msg]
embedding_2 = [
common_msg,
f"I'm fine {uuid.uuid4()}",
common_msg,
common_msg,
common_msg,
] * 20
embedding_2 = [
common_msg,
f"I'm fine {uuid.uuid4()}",
common_msg,
common_msg,
common_msg,
] * 20
embedding_3 = [
common_msg_2,
common_msg_2,
common_msg_2,
common_msg_2,
f"I'm fine {uuid.uuid4()}",
] * 20 # make sure azure doesn't return cached 'i'm fine' responses
]
embedding_val_1 = await aembedding(
model="azure/azure-embedding-model", input=embedding_1, caching=True
)
second_response_start_time = time.time()
embedding_val_2 = await aembedding(
model="azure/azure-embedding-model", input=embedding_2, caching=True
)
if embedding_val_2 is not None:
second_response_end_time = time.time()
second_response_time = second_response_end_time - second_response_start_time
third_response_start_time = time.time()
embedding_val_3 = await aembedding(
model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
)
if embedding_val_3 is not None:
third_response_end_time = time.time()
third_response_time = third_response_end_time - third_response_start_time
print(f"second_response_time: {second_response_time}")
print(f"third_response_time: {third_response_time}")
assert (
second_response_time < third_response_time - 0.5
) # make sure it's actually faster
raise Exception(f"it works {second_response_time} < {third_response_time}")
print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
assert embedding_val_2._hidden_params["cache_hit"] == True
@pytest.mark.asyncio
@@ -369,7 +334,6 @@ async def test_redis_cache_basic():
)
print(f"stored_val: {stored_val}")
assert stored_val["id"] == response1.id
raise Exception("it worked!")
def test_redis_cache_completion():