mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix: add default credential for azure (#7095) (#7891)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 11s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 11s
* fix: add default credential for azure (#7095) * fix: fix linting error * fix: remove redundant test * test: skip redundant test --------- Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
parent
c8aa876785
commit
b81072d90c
2 changed files with 25 additions and 9 deletions
|
@ -1,5 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from typing import Callable
|
from typing import Callable, Union
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
|
||||||
|
|
||||||
def get_azure_ad_token_provider() -> Callable[[], str]:
|
def get_azure_ad_token_provider() -> Callable[[], str]:
|
||||||
|
@ -14,18 +16,29 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
|
||||||
Returns:
|
Returns:
|
||||||
Callable that returns a temporary authentication token.
|
Callable that returns a temporary authentication token.
|
||||||
"""
|
"""
|
||||||
from azure.identity import ClientSecretCredential, get_bearer_token_provider
|
from azure.identity import (
|
||||||
|
ClientSecretCredential,
|
||||||
|
DefaultAzureCredential,
|
||||||
|
get_bearer_token_provider,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credential = ClientSecretCredential(
|
credential: Union[ClientSecretCredential, DefaultAzureCredential] = (
|
||||||
|
ClientSecretCredential(
|
||||||
client_id=os.environ["AZURE_CLIENT_ID"],
|
client_id=os.environ["AZURE_CLIENT_ID"],
|
||||||
client_secret=os.environ["AZURE_CLIENT_SECRET"],
|
client_secret=os.environ["AZURE_CLIENT_SECRET"],
|
||||||
tenant_id=os.environ["AZURE_TENANT_ID"],
|
tenant_id=os.environ["AZURE_TENANT_ID"],
|
||||||
)
|
)
|
||||||
|
)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise ValueError(
|
verbose_logger.exception(
|
||||||
"Missing environment variable required by Azure AD workflow. "
|
"Missing environment variable required by Azure AD workflow. "
|
||||||
) from e
|
"DefaultAzureCredential will be used"
|
||||||
|
" {}".format(str(e))
|
||||||
|
)
|
||||||
|
credential = DefaultAzureCredential()
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
return get_bearer_token_provider(
|
return get_bearer_token_provider(
|
||||||
credential,
|
credential,
|
||||||
|
|
|
@ -84,6 +84,9 @@ async def test_router_init():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="This test is not relevant to the current codebase. The default Azure AD workflow is used."
|
||||||
|
)
|
||||||
@patch("litellm.secret_managers.get_azure_ad_token_provider.os")
|
@patch("litellm.secret_managers.get_azure_ad_token_provider.os")
|
||||||
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
|
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
|
||||||
mocked_os_lib: MagicMock,
|
mocked_os_lib: MagicMock,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue