mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(utils.py): fix reordering of items for cached embeddings
ensures cached embedding item is returned in correct order
This commit is contained in:
parent
95debe0e6a
commit
075c96a408
3 changed files with 53 additions and 1 deletions
|
@ -345,6 +345,48 @@ async def test_embedding_caching_azure_individual_items():
|
|||
assert embedding_val_2._hidden_params["cache_hit"] == True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_caching_azure_individual_items_reordered():
|
||||
"""
|
||||
Tests caching for individual items in an embedding list
|
||||
|
||||
- Cache an item
|
||||
- call aembedding(..) with the item + 1 unique item
|
||||
- compare to a 2nd aembedding (...) with 2 unique items
|
||||
```
|
||||
embedding_1 = ["hey how's it going", "I'm doing well"]
|
||||
embedding_val_1 = embedding(...)
|
||||
|
||||
embedding_2 = ["hey how's it going", "I'm fine"]
|
||||
embedding_val_2 = embedding(...)
|
||||
|
||||
assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"]
|
||||
```
|
||||
"""
|
||||
litellm.cache = Cache()
|
||||
common_msg = f"{uuid.uuid4()}"
|
||||
common_msg_2 = f"hey how's it going {uuid.uuid4()}"
|
||||
embedding_1 = [common_msg_2, common_msg]
|
||||
embedding_2 = [
|
||||
common_msg,
|
||||
f"I'm fine {uuid.uuid4()}",
|
||||
]
|
||||
|
||||
embedding_val_1 = await aembedding(
|
||||
model="azure/azure-embedding-model", input=embedding_1, caching=True
|
||||
)
|
||||
embedding_val_2 = await aembedding(
|
||||
model="azure/azure-embedding-model", input=embedding_2, caching=True
|
||||
)
|
||||
print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
|
||||
assert embedding_val_2._hidden_params["cache_hit"] == True
|
||||
|
||||
assert embedding_val_2.data[0]["embedding"] == embedding_val_1.data[1]["embedding"]
|
||||
assert embedding_val_2.data[0]["index"] != embedding_val_1.data[1]["index"]
|
||||
assert embedding_val_2.data[0]["index"] == 0
|
||||
assert embedding_val_1.data[1]["index"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_cache_basic():
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue