diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 7c213abf9..bdf2fc661 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -279,6 +279,36 @@ class AzureChatCompletion(BaseLLM):
         async for transformed_chunk in streamwrapper:
             yield transformed_chunk
 
+    async def aembedding(
+        self,
+        data: dict,
+        model_response: ModelResponse,
+        azure_client_params: dict,
+        client=None,
+    ):
+        """Asynchronously request embeddings from Azure OpenAI.
+
+        Args:
+            data: Keyword arguments forwarded to ``embeddings.create``.
+            model_response: Object handed to ``convert_to_model_response_object``.
+            azure_client_params: Keyword arguments used to build an
+                ``AsyncAzureOpenAI`` client when ``client`` is None.
+            client: Optional pre-built async client to use instead.
+
+        Returns:
+            The value returned by ``convert_to_model_response_object``.
+        """
+        if client is None:
+            openai_aclient = AsyncAzureOpenAI(**azure_client_params)
+        else:
+            openai_aclient = client
+        response = await openai_aclient.embeddings.create(**data)
+        return convert_to_model_response_object(
+            response_object=json.loads(response.model_dump_json()),
+            model_response_object=model_response,
+            response_type="embedding",
+        )
+
     def embedding(self,
                   model: str,
                   input: list,
@@ -290,7 +320,8 @@ class AzureChatCompletion(BaseLLM):
                   model_response=None,
                   optional_params=None,
                   azure_ad_token: Optional[str]=None,
-                  client = None
+                  client = None,
+                  aembedding=None,
                   ):
         super().embedding()
         exception_mapping_worked = False
@@ -319,6 +350,15 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
                 azure_client_params["azure_ad_token"] = azure_ad_token
+            if aembedding is True:
+                # Async path: return the un-awaited coroutine and forward the
+                # caller-supplied client so it is used rather than dropped.
+                return self.aembedding(
+                    data=data,
+                    model_response=model_response,
+                    azure_client_params=azure_client_params,
+                    client=client,
+                )
             if client is None:
                 azure_client = AzureOpenAI(**azure_client_params) # type: ignore
             else:
diff --git a/litellm/main.py b/litellm/main.py
index 61552eead..ccac102b3 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1801,7 +1801,8 @@ def embedding(
                 timeout=timeout,
                 model_response=EmbeddingResponse(),
                 optional_params=optional_params,
-                client=client
+                client=client,
+                aembedding=aembedding
             )
         elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai":
             api_base = (
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index 782c1728c..dc648cb92 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -192,7 +192,26 @@ def test_aembedding():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_aembedding()
+# test_aembedding()
+
+
+def test_aembedding_azure():
+    try:
+        import asyncio
+        async def embedding_call():
+            try:
+                response = await litellm.aembedding(
+                    model="azure/azure-embedding-model",
+                    input=["good morning from litellm", "this is another item"]
+                )
+                print(response)
+            except Exception as e:
+                pytest.fail(f"Error occurred: {e}")
+        asyncio.run(embedding_call())
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_aembedding_azure()
 
 # def test_custom_openai_embedding():
 #     litellm.set_verbose=True