fix: add default credential for azure (#7095) (#7891)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 11s

* fix: add default credential for azure (#7095)

* fix: fix linting error

* fix: remove redundant test

* test: skip redundant test

---------

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2025-01-21 09:01:49 -08:00 committed by GitHub
parent c8aa876785
commit b81072d90c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 9 deletions

View file

@ -1,5 +1,7 @@
import os
from typing import Callable
from typing import Callable, Union
from litellm._logging import verbose_logger
def get_azure_ad_token_provider() -> Callable[[], str]:
@ -14,18 +16,29 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
Returns:
Callable that returns a temporary authentication token.
"""
from azure.identity import ClientSecretCredential, get_bearer_token_provider
from azure.identity import (
ClientSecretCredential,
DefaultAzureCredential,
get_bearer_token_provider,
)
try:
credential = ClientSecretCredential(
credential: Union[ClientSecretCredential, DefaultAzureCredential] = (
ClientSecretCredential(
client_id=os.environ["AZURE_CLIENT_ID"],
client_secret=os.environ["AZURE_CLIENT_SECRET"],
tenant_id=os.environ["AZURE_TENANT_ID"],
)
)
except KeyError as e:
raise ValueError(
"Missing environment variable required by Azure AD workflow."
) from e
verbose_logger.exception(
"Missing environment variable required by Azure AD workflow. "
"DefaultAzureCredential will be used"
" {}".format(str(e))
)
credential = DefaultAzureCredential()
except Exception:
raise
return get_bearer_token_provider(
credential,

View file

@ -84,6 +84,9 @@ async def test_router_init():
)
@pytest.mark.skip(
reason="This test is not relevant to the current codebase. The default Azure AD workflow is used."
)
@patch("litellm.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
mocked_os_lib: MagicMock,