feat: prioritize api_key over tenant_id for more Azure AD token provi… (#8701)

* feat: prioritize api_key over tenant_id for more Azure AD token provider (#8318)

* fix: prioritize api_key over tenant_id for Azure AD token provider

* test: Add test for Azure AD token provider in router

* fix: fix linting error

---------

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2025-03-09 18:59:37 -07:00 committed by GitHub
parent e00d4fb18c
commit 65ef65d360
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 54 additions and 31 deletions

View file

@ -195,7 +195,8 @@ class InitalizeOpenAISDKClient:
organization = get_secret_str(organization_env_name)
litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"):
# If we have api_key, then we have higher priority
if not api_key and litellm_params.get("tenant_id"):
verbose_router_logger.debug(
"Using Azure AD Token Provider for Azure Auth"
)
@ -232,7 +233,7 @@ class InitalizeOpenAISDKClient:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
elif (
azure_ad_token_provider is None
not api_key and azure_ad_token_provider is None
and litellm.enable_azure_ad_token_refresh is True
):
try:

View file

@ -1,7 +1,5 @@
import os
from typing import Callable, Union
from litellm._logging import verbose_logger
from typing import Callable
def get_azure_ad_token_provider() -> Callable[[], str]:
@ -16,31 +14,23 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
Returns:
Callable that returns a temporary authentication token.
"""
from azure.identity import (
ClientSecretCredential,
DefaultAzureCredential,
get_bearer_token_provider,
)
import azure.identity as identity
from azure.identity import get_bearer_token_provider
try:
credential: Union[ClientSecretCredential, DefaultAzureCredential] = (
ClientSecretCredential(
client_id=os.environ["AZURE_CLIENT_ID"],
client_secret=os.environ["AZURE_CLIENT_SECRET"],
tenant_id=os.environ["AZURE_TENANT_ID"],
)
)
except KeyError as e:
verbose_logger.exception(
"Missing environment variable required by Azure AD workflow. "
"DefaultAzureCredential will be used"
" {}".format(str(e))
)
credential = DefaultAzureCredential()
except Exception:
raise
return get_bearer_token_provider(
credential,
"https://cognitiveservices.azure.com/.default",
azure_scope = os.environ.get(
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
)
cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")
cred_cls = getattr(identity, cred)
# ClientSecretCredential, DefaultAzureCredential, AzureCliCredential
if cred == "ClientSecretCredential":
credential = cred_cls(
client_id=os.environ["AZURE_CLIENT_ID"],
client_secret=os.environ["AZURE_CLIENT_SECRET"],
tenant_id=os.environ["AZURE_TENANT_ID"],
)
else:
credential = cred_cls()
return get_bearer_token_provider(credential, azure_scope)

View file

@ -219,6 +219,38 @@ def test_router_azure_ai_client_init():
assert not isinstance(_client, AsyncAzureOpenAI)
def test_router_azure_ad_token_provider():
_deployment = {
"model_name": "gpt-4o_2024-05-13",
"litellm_params": {
"model": "azure/gpt-4o_2024-05-13",
"api_base": "my-fake-route",
"api_version": "2024-08-01-preview",
},
"model_info": {"id": "1234"},
}
for azure_cred in ["DefaultAzureCredential", "AzureCliCredential"]:
os.environ["AZURE_CREDENTIAL"] = azure_cred
litellm.enable_azure_ad_token_refresh = True
router = Router(model_list=[_deployment])
_client = router._get_client(
deployment=_deployment,
client_type="async",
kwargs={"stream": False},
)
print(_client)
import azure.identity as identity
from openai import AsyncAzureOpenAI, AsyncOpenAI
assert isinstance(_client, AsyncOpenAI)
assert isinstance(_client, AsyncAzureOpenAI)
assert _client._azure_ad_token_provider is not None
assert isinstance(_client._azure_ad_token_provider.__closure__, tuple)
assert isinstance(_client._azure_ad_token_provider.__closure__[0].cell_contents._credential,
getattr(identity, os.environ["AZURE_CREDENTIAL"]))
def test_router_sensitive_keys():
try:
router = Router(