mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
test(test_embeddings.py): fix test
This commit is contained in:
parent
3df1186d72
commit
3cec00939e
2 changed files with 30 additions and 18 deletions
|
@ -3366,7 +3366,10 @@ def embedding(
|
||||||
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||||
|
|
||||||
api_version = (
|
api_version = (
|
||||||
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
|
api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
or litellm.AZURE_DEFAULT_API_VERSION
|
||||||
)
|
)
|
||||||
|
|
||||||
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
|
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
|
||||||
|
|
|
@ -226,31 +226,40 @@ def test_openai_azure_embedding_with_oidc_and_cf():
|
||||||
os.environ["AZURE_API_KEY"] = old_key
|
os.environ["AZURE_API_KEY"] = old_key
|
||||||
|
|
||||||
|
|
||||||
def test_openai_azure_embedding_optional_arg(mocker):
|
def _openai_mock_response(*args, **kwargs):
|
||||||
mocked_create_embeddings = mocker.patch.object(
|
new_response = MagicMock()
|
||||||
openai.resources.embeddings.Embeddings,
|
new_response.headers = {"hello": "world"}
|
||||||
"create",
|
|
||||||
return_value=openai.types.create_embedding_response.CreateEmbeddingResponse(
|
new_response.parse.return_value = (
|
||||||
|
openai.types.create_embedding_response.CreateEmbeddingResponse(
|
||||||
data=[],
|
data=[],
|
||||||
model="azure/test",
|
model="azure/test",
|
||||||
object="list",
|
object="list",
|
||||||
usage=openai.types.create_embedding_response.Usage(
|
usage=openai.types.create_embedding_response.Usage(
|
||||||
prompt_tokens=1, total_tokens=2
|
prompt_tokens=1, total_tokens=2
|
||||||
),
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
_ = litellm.embedding(
|
|
||||||
model="azure/test",
|
|
||||||
input=["test"],
|
|
||||||
api_version="test",
|
|
||||||
api_base="test",
|
|
||||||
azure_ad_token="test",
|
|
||||||
)
|
)
|
||||||
|
return new_response
|
||||||
|
|
||||||
assert mocked_create_embeddings.called_once_with(
|
|
||||||
model="test", input=["test"], timeout=600
|
def test_openai_azure_embedding_optional_arg():
|
||||||
)
|
|
||||||
assert "azure_ad_token" not in mocked_create_embeddings.call_args.kwargs
|
with patch.object(
|
||||||
|
openai.resources.embeddings.Embeddings,
|
||||||
|
"create",
|
||||||
|
side_effect=_openai_mock_response,
|
||||||
|
) as mock_client:
|
||||||
|
_ = litellm.embedding(
|
||||||
|
model="azure/test",
|
||||||
|
input=["test"],
|
||||||
|
api_version="test",
|
||||||
|
api_base="test",
|
||||||
|
azure_ad_token="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.called_once_with(model="test", input=["test"], timeout=600)
|
||||||
|
assert "azure_ad_token" not in mock_client.call_args.kwargs
|
||||||
|
|
||||||
|
|
||||||
# test_openai_azure_embedding()
|
# test_openai_azure_embedding()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue