diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 90209a9e6..8d0afd067 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -12,6 +12,10 @@ model_list:
     api_version: "2023-07-01-preview"
     stream_timeout: 0.001
   model_name: azure-gpt-3.5
+- model_name: text-embedding-ada-002
+  litellm_params:
+    model: text-embedding-ada-002
+    api_key: os.environ/OPENAI_API_KEY
 - model_name: gpt-instruct
   litellm_params:
     model: gpt-3.5-turbo-instruct
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index da46359d9..835d3611b 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -345,6 +345,48 @@ async def test_embedding_caching_azure_individual_items():
     assert embedding_val_2._hidden_params["cache_hit"] == True
 
 
+@pytest.mark.asyncio
+async def test_embedding_caching_azure_individual_items_reordered():
+    """
+    Tests caching for individual items in an embedding list when a cached item is reordered
+
+    - Cache two items via aembedding(..)
+    - Call aembedding(..) again with one cached item at a new position + 1 unique item
+    - Verify the cached embedding is returned, re-indexed to its position in the new request
+    ```
+    embedding_1 = ["I'm doing well", "hey how's it going"]
+    embedding_val_1 = embedding(...)
+
+    embedding_2 = ["hey how's it going", "I'm fine"]
+    embedding_val_2 = embedding(...)
+
+    assert embedding_val_2[0]["embedding"] == embedding_val_1[1]["embedding"]
+    ```
+    """
+    litellm.cache = Cache()
+    common_msg = f"{uuid.uuid4()}"
+    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [common_msg_2, common_msg]
+    embedding_2 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+    ]
+
+    embedding_val_1 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_1, caching=True
+    )
+    embedding_val_2 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_2, caching=True
+    )
+    print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
+    assert embedding_val_2._hidden_params["cache_hit"] == True
+
+    assert embedding_val_2.data[0]["embedding"] == embedding_val_1.data[1]["embedding"]
+    assert embedding_val_2.data[0]["index"] != embedding_val_1.data[1]["index"]
+    assert embedding_val_2.data[0]["index"] == 0
+    assert embedding_val_1.data[1]["index"] == 1
+
+
 @pytest.mark.asyncio
 async def test_redis_cache_basic():
     """
diff --git a/litellm/utils.py b/litellm/utils.py
index 6a58d56db..affa811f3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3174,7 +3174,13 @@ def client(original_function):
                         for val in non_null_list:
                             idx, cr = val  # (idx, cr) tuple
                             if cr is not None:
-                                final_embedding_cached_response.data[idx] = cr
+                                final_embedding_cached_response.data[idx] = (
+                                    Embedding(
+                                        embedding=cr["embedding"],
+                                        index=idx,
+                                        object="embedding",
+                                    )
+                                )
                     if len(remaining_list) == 0:
                         # LOG SUCCESS
                         cache_hit = True
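For context, the utils.py hunk re-wraps each cached result in an `Embedding` object stamped with the item's index in the *current* request, so a text cached at one position can be returned at a different position later, which is exactly what the new reordered test asserts. Below is a minimal, self-contained sketch of that idea; the `Embedding` dataclass, `embed_with_cache` helper, and in-memory dict cache are illustrative stand-ins, not litellm's actual `Cache`/`EmbeddingResponse` types.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class Embedding:
    """Stand-in for the per-item response object; only the fields the diff touches."""
    embedding: List[float]
    index: int
    object: str = "embedding"


def embed_with_cache(
    inputs: List[str],
    cache: Dict[str, List[float]],
    compute: Callable[[str], List[float]],
) -> List[Embedding]:
    """Return one Embedding per input, reusing cached vectors where possible.

    A cache hit is re-wrapped with the index it occupies in *this* request,
    mirroring the utils.py change above: the same text may have sat at a
    different position in the request that originally populated the cache.
    """
    results: List[Optional[Embedding]] = [None] * len(inputs)
    remaining = []

    for idx, text in enumerate(inputs):
        cached = cache.get(text)
        if cached is not None:
            # Wrap the raw cached vector, stamping the position in the current request.
            results[idx] = Embedding(embedding=cached, index=idx)
        else:
            remaining.append((idx, text))

    # Compute only the misses, then backfill both the cache and the results.
    for idx, text in remaining:
        vector = compute(text)
        cache[text] = vector
        results[idx] = Embedding(embedding=vector, index=idx)

    return results  # every slot is filled by now


if __name__ == "__main__":
    cache: Dict[str, List[float]] = {}
    fake_model = lambda text: [float(len(text))]  # placeholder for the real embedding call

    first = embed_with_cache(["I'm doing well", "common item"], cache, fake_model)
    second = embed_with_cache(["common item", "I'm fine"], cache, fake_model)

    # The cached vector for "common item" is reused, but its index follows the new
    # request order -- the property the reordered test above checks against litellm.
    assert second[0].embedding == first[1].embedding
    assert second[0].index == 0 and first[1].index == 1
```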