fix(utils.py): fix reordering of items for cached embeddings

Ensures that a cached embedding item is returned at the correct index when the input list order differs from the order in which it was originally cached.
This commit is contained in:
Krrish Dholakia 2024-04-08 12:17:57 -07:00
parent 28e4706bfd
commit 48bfc45cb0
3 changed files with 53 additions and 1 deletion

View file

@ -12,6 +12,10 @@ model_list:
api_version: "2023-07-01-preview"
stream_timeout: 0.001
model_name: azure-gpt-3.5
- model_name: text-embedding-ada-002
litellm_params:
model: text-embedding-ada-002
api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-instruct
litellm_params:
model: gpt-3.5-turbo-instruct

View file

@ -345,6 +345,48 @@ async def test_embedding_caching_azure_individual_items():
assert embedding_val_2._hidden_params["cache_hit"] == True
@pytest.mark.asyncio
async def test_embedding_caching_azure_individual_items_reordered():
    """
    Verify per-item embedding caching is order-independent.

    Scenario:
    - Request 1 embeds [unique_a, shared]: the shared item is cached at index 1.
    - Request 2 embeds [shared, unique_b]: the shared item is a cache hit,
      now at index 0.
    - The cached vector must be byte-identical across both responses, while the
      reported ``index`` must follow each request's own ordering — i.e. the
      cache must remap indices rather than replay the cached position.
    """
    litellm.cache = Cache()

    shared_item = f"{uuid.uuid4()}"
    unique_item = f"hey how's it going {uuid.uuid4()}"
    first_input = [unique_item, shared_item]
    second_input = [
        shared_item,
        f"I'm fine {uuid.uuid4()}",
    ]

    first_response = await aembedding(
        model="azure/azure-embedding-model", input=first_input, caching=True
    )
    second_response = await aembedding(
        model="azure/azure-embedding-model", input=second_input, caching=True
    )
    print(f"embedding_val_2._hidden_params: {second_response._hidden_params}")

    # The second call must register the shared item as a cache hit.
    assert second_response._hidden_params["cache_hit"] == True
    # Same vector for the shared item across both responses...
    assert second_response.data[0]["embedding"] == first_response.data[1]["embedding"]
    # ...but the index reflects the position within each request.
    assert second_response.data[0]["index"] != first_response.data[1]["index"]
    assert second_response.data[0]["index"] == 0
    assert first_response.data[1]["index"] == 1
@pytest.mark.asyncio
async def test_redis_cache_basic():
"""

View file

@ -3174,7 +3174,13 @@ def client(original_function):
for val in non_null_list:
idx, cr = val # (idx, cr) tuple
if cr is not None:
final_embedding_cached_response.data[idx] = cr
final_embedding_cached_response.data[idx] = (
Embedding(
embedding=cr["embedding"],
index=idx,
object="embedding",
)
)
if len(remaining_list) == 0:
# LOG SUCCESS
cache_hit = True