forked from phoenix/litellm-mirror
Fixed azure ad token not being processed properly in embedding models
This commit is contained in:
parent
0d18f3c0ca
commit
57ebb9582e
4 changed files with 535 additions and 383 deletions
|
@ -2606,7 +2606,7 @@ def embedding(
|
||||||
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
|
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
|
||||||
)
|
)
|
||||||
|
|
||||||
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
|
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
|
||||||
"AZURE_AD_TOKEN"
|
"AZURE_AD_TOKEN"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -191,6 +191,33 @@ def test_openai_azure_embedding():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_azure_embedding_optional_arg(mocker):
|
||||||
|
mocked_create_embeddings = mocker.patch.object(
|
||||||
|
openai.resources.embeddings.Embeddings,
|
||||||
|
"create",
|
||||||
|
return_value=openai.types.create_embedding_response.CreateEmbeddingResponse(
|
||||||
|
data=[],
|
||||||
|
model="azure/test",
|
||||||
|
object="list",
|
||||||
|
usage=openai.types.create_embedding_response.Usage(
|
||||||
|
prompt_tokens=1, total_tokens=2
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_ = litellm.embedding(
|
||||||
|
model="azure/test",
|
||||||
|
input=["test"],
|
||||||
|
api_version="test",
|
||||||
|
api_base="test",
|
||||||
|
azure_ad_token="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mocked_create_embeddings.called_once_with(
|
||||||
|
model="test", input=["test"], timeout=600
|
||||||
|
)
|
||||||
|
assert "azure_ad_token" not in mocked_create_embeddings.call_args.kwargs
|
||||||
|
|
||||||
|
|
||||||
# test_openai_azure_embedding()
|
# test_openai_azure_embedding()
|
||||||
|
|
||||||
# test_openai_embedding()
|
# test_openai_embedding()
|
||||||
|
|
888
poetry.lock
generated
888
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -70,6 +70,7 @@ litellm = 'litellm:run_server'
|
||||||
flake8 = "^6.1.0"
|
flake8 = "^6.1.0"
|
||||||
black = "^23.12.0"
|
black = "^23.12.0"
|
||||||
pytest = "^7.4.3"
|
pytest = "^7.4.3"
|
||||||
|
pytest-mock = "^3.12.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core", "wheel"]
|
requires = ["poetry-core", "wheel"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue