test(test_embeddings.py): fix test

This commit is contained in:
Krrish Dholakia 2024-08-28 07:50:44 -07:00
parent 3df1186d72
commit 3cec00939e
2 changed files with 30 additions and 18 deletions

View file

@ -226,31 +226,40 @@ def test_openai_azure_embedding_with_oidc_and_cf():
os.environ["AZURE_API_KEY"] = old_key
def test_openai_azure_embedding_optional_arg(mocker):
mocked_create_embeddings = mocker.patch.object(
openai.resources.embeddings.Embeddings,
"create",
return_value=openai.types.create_embedding_response.CreateEmbeddingResponse(
def _openai_mock_response(*args, **kwargs):
new_response = MagicMock()
new_response.headers = {"hello": "world"}
new_response.parse.return_value = (
openai.types.create_embedding_response.CreateEmbeddingResponse(
data=[],
model="azure/test",
object="list",
usage=openai.types.create_embedding_response.Usage(
prompt_tokens=1, total_tokens=2
),
),
)
_ = litellm.embedding(
model="azure/test",
input=["test"],
api_version="test",
api_base="test",
azure_ad_token="test",
)
)
return new_response
assert mocked_create_embeddings.called_once_with(
model="test", input=["test"], timeout=600
)
assert "azure_ad_token" not in mocked_create_embeddings.call_args.kwargs
def test_openai_azure_embedding_optional_arg():
with patch.object(
openai.resources.embeddings.Embeddings,
"create",
side_effect=_openai_mock_response,
) as mock_client:
_ = litellm.embedding(
model="azure/test",
input=["test"],
api_version="test",
api_base="test",
azure_ad_token="test",
)
assert mock_client.called_once_with(model="test", input=["test"], timeout=600)
assert "azure_ad_token" not in mock_client.call_args.kwargs
# test_openai_azure_embedding()