fix: support async redis caching

Krrish Dholakia 2024-01-12 21:46:41 +05:30
parent 817a3d29b7
commit 007870390d
6 changed files with 357 additions and 122 deletions
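A minimal sketch of the end-to-end flow this commit enables, using only calls that appear in the tests below (`Cache(type="redis", ...)`, `get_cache_key`, the sync `add_cache`, and the new `async_get_cache`); the model name and message contents are placeholders:

```python
import asyncio
import os

import litellm
from litellm import Cache, completion


async def main():
    # Redis-backed cache, configured from the same env vars the tests use
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
    messages = [{"role": "user", "content": "write a one sentence poem"}]

    # sync write path: compute the cache key and store a completion result
    response = completion(model="gpt-3.5-turbo", messages=messages)
    cache_key = litellm.cache.get_cache_key(model="gpt-3.5-turbo", messages=messages)
    litellm.cache.add_cache(result=response, cache_key=cache_key)

    # async read path added by this commit
    stored = await litellm.cache.async_get_cache(
        model="gpt-3.5-turbo", messages=messages
    )
    print(stored["id"] == response.id)


asyncio.run(main())
```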


@@ -266,8 +266,9 @@ async def test_embedding_caching_azure_individual_items():
"""
Tests caching for individual items in an embedding list
Assert if the same embeddingresponse object is returned for the duplicate item in 2 embedding list calls
- Cache an item
- call aembedding(..) with the item + 1 unique item
- compare to a 2nd aembedding (...) with 2 unique items
```
embedding_1 = ["hey how's it going", "I'm doing well"]
embedding_val_1 = embedding(...)
@@ -280,31 +281,98 @@ async def test_embedding_caching_azure_individual_items():
"""
litellm.cache = Cache()
common_msg = f"hey how's it going {uuid.uuid4()}"
embedding_1 = [common_msg, "I'm doing well"]
    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
    embedding_2 = [
        common_msg,
        f"I'm fine {uuid.uuid4()}",
        common_msg,
        common_msg,
        common_msg,
    ] * 20
    embedding_3 = [
        common_msg_2,
        common_msg_2,
        common_msg_2,
        common_msg_2,
        f"I'm fine {uuid.uuid4()}",
    ] * 20  # make sure azure doesn't return cached "I'm fine" responses
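    # seed the cache: the first aembedding call writes an entry per input item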
    embedding_val_1 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_1, caching=True
    )
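    # second call repeats common_msg, so those items should be served from the cache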
    second_response_start_time = time.time()
    embedding_val_2 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_2, caching=True
    )
    print(f"embedding_val_2: {embedding_val_2}")
    if (
        embedding_val_2["data"][0]["embedding"]
        != embedding_val_1["data"][0]["embedding"]
    ):
        print(f"embedding1: {embedding_val_1}")
        print(f"embedding2: {embedding_val_2}")
        pytest.fail("Error occurred: Embedding caching failed")
    if (
        embedding_val_2["data"][1]["embedding"]
        == embedding_val_1["data"][1]["embedding"]
    ):
        print(f"embedding1: {embedding_val_1}")
        print(f"embedding2: {embedding_val_2}")
        pytest.fail("Error occurred: Embedding caching failed")
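    # time the partially cached call against a fully uncached call on a fresh list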
    if embedding_val_2 is not None:
        second_response_end_time = time.time()
        second_response_time = second_response_end_time - second_response_start_time

    third_response_start_time = time.time()
    embedding_val_3 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
    )
    if embedding_val_3 is not None:
        third_response_end_time = time.time()
        third_response_time = third_response_end_time - third_response_start_time

    print(f"second_response_time: {second_response_time}")
    print(f"third_response_time: {third_response_time}")
    assert (
        second_response_time < third_response_time - 0.5
    )  # make sure it's actually faster
    raise Exception(f"it works {second_response_time} < {third_response_time}")
@pytest.mark.asyncio
async def test_redis_cache_basic():
    """
    Init redis client
    - write to client
    - read from client
    """
    litellm.set_verbose = False
    random_number = random.randint(
        1, 100000
    )  # add a random number to ensure it's always adding / reading from cache
    messages = [
        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
    ]
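    # point the global litellm cache at Redis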
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
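    # make a normal completion call, then write it to the cache by hand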
    response1 = completion(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    cache_key = litellm.cache.get_cache_key(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    print(f"cache_key: {cache_key}")
    litellm.cache.add_cache(result=response1, cache_key=cache_key)
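    # read the entry back through the async Redis path this commit adds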
print(f"cache key pre async get: {cache_key}")
stored_val = await litellm.cache.async_get_cache(
model="gpt-3.5-turbo",
messages=messages,
)
print(f"stored_val: {stored_val}")
assert stored_val["id"] == response1.id
raise Exception("it worked!")


def test_redis_cache_completion():