diff --git a/litellm/main.py b/litellm/main.py index a77a03522a..7eac01ff6f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3366,7 +3366,10 @@ def embedding( api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + or litellm.AZURE_DEFAULT_API_VERSION ) azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 31268395f1..54c823e4dc 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -226,31 +226,40 @@ def test_openai_azure_embedding_with_oidc_and_cf(): os.environ["AZURE_API_KEY"] = old_key -def test_openai_azure_embedding_optional_arg(mocker): - mocked_create_embeddings = mocker.patch.object( - openai.resources.embeddings.Embeddings, - "create", - return_value=openai.types.create_embedding_response.CreateEmbeddingResponse( +def _openai_mock_response(*args, **kwargs): + new_response = MagicMock() + new_response.headers = {"hello": "world"} + + new_response.parse.return_value = ( + openai.types.create_embedding_response.CreateEmbeddingResponse( data=[], model="azure/test", object="list", usage=openai.types.create_embedding_response.Usage( prompt_tokens=1, total_tokens=2 ), - ), - ) - _ = litellm.embedding( - model="azure/test", - input=["test"], - api_version="test", - api_base="test", - azure_ad_token="test", + ) ) + return new_response - assert mocked_create_embeddings.called_once_with( - model="test", input=["test"], timeout=600 - ) - assert "azure_ad_token" not in mocked_create_embeddings.call_args.kwargs + +def test_openai_azure_embedding_optional_arg(): + + with patch.object( + openai.resources.embeddings.Embeddings, + "create", + side_effect=_openai_mock_response, + ) as mock_client: + _ = litellm.embedding( + model="azure/test", + input=["test"], + api_version="test", + api_base="test", + azure_ad_token="test", + ) + + assert mock_client.called_once_with(model="test", input=["test"], timeout=600) + assert "azure_ad_token" not in mock_client.call_args.kwargs # test_openai_azure_embedding()