fix: add default credential for azure (#7095) (#7891)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 11s

* fix: add default credential for azure (#7095)

* fix: fix linting error

* fix: remove redundant test

* test: skip redundant test

---------

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2025-01-21 09:01:49 -08:00 committed by GitHub
parent c8aa876785
commit b81072d90c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 9 deletions

View file

@ -1,5 +1,7 @@
import os import os
from typing import Callable from typing import Callable, Union
from litellm._logging import verbose_logger
def get_azure_ad_token_provider() -> Callable[[], str]: def get_azure_ad_token_provider() -> Callable[[], str]:
@ -14,18 +16,29 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
Returns: Returns:
Callable that returns a temporary authentication token. Callable that returns a temporary authentication token.
""" """
from azure.identity import ClientSecretCredential, get_bearer_token_provider from azure.identity import (
ClientSecretCredential,
DefaultAzureCredential,
get_bearer_token_provider,
)
try: try:
credential = ClientSecretCredential( credential: Union[ClientSecretCredential, DefaultAzureCredential] = (
client_id=os.environ["AZURE_CLIENT_ID"], ClientSecretCredential(
client_secret=os.environ["AZURE_CLIENT_SECRET"], client_id=os.environ["AZURE_CLIENT_ID"],
tenant_id=os.environ["AZURE_TENANT_ID"], client_secret=os.environ["AZURE_CLIENT_SECRET"],
tenant_id=os.environ["AZURE_TENANT_ID"],
)
) )
except KeyError as e: except KeyError as e:
raise ValueError( verbose_logger.exception(
"Missing environment variable required by Azure AD workflow." "Missing environment variable required by Azure AD workflow. "
) from e "DefaultAzureCredential will be used"
" {}".format(str(e))
)
credential = DefaultAzureCredential()
except Exception:
raise
return get_bearer_token_provider( return get_bearer_token_provider(
credential, credential,

View file

@ -84,6 +84,9 @@ async def test_router_init():
) )
@pytest.mark.skip(
reason="This test is not relevant to the current codebase. The default Azure AD workflow is used."
)
@patch("litellm.secret_managers.get_azure_ad_token_provider.os") @patch("litellm.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret( def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
mocked_os_lib: MagicMock, mocked_os_lib: MagicMock,