forked from phoenix/litellm-mirror

fix: support async redis caching

parent: 817a3d29b7
commit: 007870390d

6 changed files with 357 additions and 122 deletions
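The commit under review adds async read support to litellm's Redis cache. For orientation, a minimal async Redis round trip with redis-py's `redis.asyncio` client looks like the sketch below; this is background on the client being wrapped, not litellm's own code.

```python
# Minimal async Redis round trip using redis-py's asyncio client.
# Background for the feature under test, not litellm's implementation.
import asyncio
import json

import redis.asyncio as redis


async def main() -> None:
    client = redis.Redis(host="localhost", port=6379)
    await client.set("litellm:demo-key", json.dumps({"id": "chatcmpl-123"}))
    raw = await client.get("litellm:demo-key")
    print(json.loads(raw))
    await client.close()


asyncio.run(main())
```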
@@ -266,8 +266,9 @@ async def test_embedding_caching_azure_individual_items():
    """
    Tests caching for individual items in an embedding list

    Assert that the same EmbeddingResponse object is returned for the duplicate item across 2 embedding list calls

    - Cache an item
    - call aembedding(...) with the item + 1 unique item
    - compare to a 2nd aembedding(...) with 2 unique items
    ```
    embedding_1 = ["hey how's it going", "I'm doing well"]
    embedding_val_1 = embedding(...)
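The docstring above describes per-item caching: each input string gets its own cache entry, so a repeated item is served from cache even when it arrives inside a fresh list. A minimal sketch of that lookup pattern follows, with hypothetical helpers (`cache_lookup`, `cache_store`, `embed_with_cache`) that are stand-ins rather than litellm internals.

```python
# Sketch of per-item embedding caching, the pattern this test exercises.
# cache_lookup / cache_store / embed_with_cache are hypothetical stand-ins,
# not litellm's real internals.
from typing import Dict, List

_cache: Dict[str, List[float]] = {}


def cache_lookup(text: str):
    return _cache.get(text)


def cache_store(text: str, vector: List[float]) -> None:
    _cache[text] = vector


def embed_with_cache(inputs: List[str], embed_batch) -> List[List[float]]:
    # Embed only the cache misses, then stitch results back in input order.
    misses = [s for s in inputs if cache_lookup(s) is None]
    if misses:
        for text, vector in zip(misses, embed_batch(misses)):
            cache_store(text, vector)
    return [cache_lookup(s) for s in inputs]
```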
@@ -280,31 +281,98 @@ async def test_embedding_caching_azure_individual_items():
    """
    litellm.cache = Cache()
    common_msg = f"hey how's it going {uuid.uuid4()}"
    embedding_1 = [common_msg, "I'm doing well"]
    embedding_2 = [common_msg, "I'm fine"]
    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
    embedding_2 = [
        common_msg,
        f"I'm fine {uuid.uuid4()}",
        common_msg,
        common_msg,
        common_msg,
    ] * 20
    embedding_3 = [
        common_msg_2,
        common_msg_2,
        common_msg_2,
        common_msg_2,
        f"I'm fine {uuid.uuid4()}",
    ] * 20  # make sure azure doesn't return cached 'i'm fine' responses
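A quick property of these inputs: because `* 20` replicates the same two objects, each 100-item list contains exactly two distinct strings, so a per-item cache should reduce each request to at most two real embedding computations. A sanity check:

```python
# Each 100-item list holds exactly two distinct strings, so per-item
# caching collapses it to (at most) two actual embedding computations.
import uuid

common_msg = f"hey how's it going {uuid.uuid4()}"
embedding_2 = [
    common_msg,
    f"I'm fine {uuid.uuid4()}",  # evaluated once, then replicated by * 20
    common_msg,
    common_msg,
    common_msg,
] * 20

assert len(embedding_2) == 100
assert len(set(embedding_2)) == 2
```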
    embedding_val_1 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_1, caching=True
    )

    second_response_start_time = time.time()
    embedding_val_2 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_2, caching=True
    )
    print(f"embedding_val_2: {embedding_val_2}")
    # the shared first item must come back identical (cache hit)...
    if (
        embedding_val_2["data"][0]["embedding"]
        != embedding_val_1["data"][0]["embedding"]
    ):
        print(f"embedding1: {embedding_val_1}")
        print(f"embedding2: {embedding_val_2}")
        pytest.fail("Error occurred: Embedding caching failed")
    # ...while the unique second item must produce a different embedding
    if (
        embedding_val_2["data"][1]["embedding"]
        == embedding_val_1["data"][1]["embedding"]
    ):
        print(f"embedding1: {embedding_val_1}")
        print(f"embedding2: {embedding_val_2}")
        pytest.fail("Error occurred: Embedding caching failed")
    if embedding_val_2 is not None:
        second_response_end_time = time.time()
        second_response_time = second_response_end_time - second_response_start_time

    third_response_start_time = time.time()
    embedding_val_3 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
    )
    if embedding_val_3 is not None:
        third_response_end_time = time.time()
        third_response_time = third_response_end_time - third_response_start_time

    print(f"second_response_time: {second_response_time}")
    print(f"third_response_time: {third_response_time}")

    assert (
        second_response_time < third_response_time - 0.5
    )  # make sure it's actually faster
    raise Exception(f"it works {second_response_time} < {third_response_time}")  # debug aid: fails loudly so the timings show in test output
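Two per-call knobs do the work in this test: `caching=True` opts an individual litellm call into the configured cache, and `cache={"no-cache": True}` forces a fresh upstream call even when a cached result exists, which is what makes the second/third timing comparison meaningful. A usage sketch (the model name is a placeholder):

```python
# Per-request cache controls, as used in the test above.
import litellm
from litellm import acompletion
from litellm.caching import Cache

litellm.cache = Cache()  # default in-memory cache


async def demo():
    messages = [{"role": "user", "content": "hello"}]
    # Opt this call into the configured cache:
    cached = await acompletion(model="gpt-3.5-turbo", messages=messages, caching=True)
    # Force a fresh call even if a cached result exists:
    fresh = await acompletion(
        model="gpt-3.5-turbo", messages=messages, cache={"no-cache": True}
    )
    return cached, fresh
```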
@pytest.mark.asyncio
async def test_redis_cache_basic():
    """
    Init redis client
    - write to client
    - read from client
    """
    litellm.set_verbose = False

    random_number = random.randint(
        1, 100000
    )  # add a random number to ensure it's always adding / reading from cache
    messages = [
        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
    ]
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
    response1 = completion(
        model="gpt-3.5-turbo",
        messages=messages,
    )

    cache_key = litellm.cache.get_cache_key(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    print(f"cache_key: {cache_key}")
    # write synchronously, then read back through the new async path
    litellm.cache.add_cache(result=response1, cache_key=cache_key)
    print(f"cache key pre async get: {cache_key}")
    stored_val = await litellm.cache.async_get_cache(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    print(f"stored_val: {stored_val}")
    assert stored_val["id"] == response1.id
    raise Exception("it worked!")  # debug aid: fails loudly so a pass is visible in CI logs
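`Cache.async_get_cache` is the new surface this commit exercises: the Redis read is awaited instead of blocking the event loop. The excerpt doesn't show litellm's implementation, so the following is only a hedged sketch of the shape such a method can take, assuming a `redis.asyncio` client and JSON-serialized values.

```python
# Hedged sketch of an async cache read in the spirit of async_get_cache.
# Illustration only, not litellm's actual code: the real method derives
# the key from the call arguments and handles its own encoding.
import json
from typing import Any, Optional

import redis.asyncio as redis


class AsyncRedisCacheSketch:
    def __init__(self, host: str, port: int, password: Optional[str] = None):
        self.client = redis.Redis(host=host, port=port, password=password)

    async def async_get_cache(self, key: str) -> Optional[Any]:
        # Non-blocking read: the event loop stays free while Redis responds.
        raw = await self.client.get(key)
        return json.loads(raw) if raw is not None else None

    async def async_add_cache(self, key: str, value: Any) -> None:
        await self.client.set(key, json.dumps(value))
```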
def test_redis_cache_completion():