test(test_embeddings.py): fix test

This commit is contained in:
Krrish Dholakia 2024-08-28 07:50:44 -07:00
parent 3df1186d72
commit 3cec00939e
2 changed files with 30 additions and 18 deletions

View file

@ -3366,7 +3366,10 @@ def embedding(
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_version = ( api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION") api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or litellm.AZURE_DEFAULT_API_VERSION
) )
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(

View file

@ -226,31 +226,40 @@ def test_openai_azure_embedding_with_oidc_and_cf():
os.environ["AZURE_API_KEY"] = old_key os.environ["AZURE_API_KEY"] = old_key
def test_openai_azure_embedding_optional_arg(mocker): def _openai_mock_response(*args, **kwargs):
mocked_create_embeddings = mocker.patch.object( new_response = MagicMock()
openai.resources.embeddings.Embeddings, new_response.headers = {"hello": "world"}
"create",
return_value=openai.types.create_embedding_response.CreateEmbeddingResponse( new_response.parse.return_value = (
openai.types.create_embedding_response.CreateEmbeddingResponse(
data=[], data=[],
model="azure/test", model="azure/test",
object="list", object="list",
usage=openai.types.create_embedding_response.Usage( usage=openai.types.create_embedding_response.Usage(
prompt_tokens=1, total_tokens=2 prompt_tokens=1, total_tokens=2
), ),
), )
)
_ = litellm.embedding(
model="azure/test",
input=["test"],
api_version="test",
api_base="test",
azure_ad_token="test",
) )
return new_response
assert mocked_create_embeddings.called_once_with(
model="test", input=["test"], timeout=600 def test_openai_azure_embedding_optional_arg():
)
assert "azure_ad_token" not in mocked_create_embeddings.call_args.kwargs with patch.object(
openai.resources.embeddings.Embeddings,
"create",
side_effect=_openai_mock_response,
) as mock_client:
_ = litellm.embedding(
model="azure/test",
input=["test"],
api_version="test",
api_base="test",
azure_ad_token="test",
)
assert mock_client.called_once_with(model="test", input=["test"], timeout=600)
assert "azure_ad_token" not in mock_client.call_args.kwargs
# test_openai_azure_embedding() # test_openai_azure_embedding()