feat: prioritize api_key over tenant_id for more Azure AD token provi… (#8701)

* feat: prioritize api_key over tenant_id for more Azure AD token provider (#8318)

* fix: prioritize api_key over tenant_id for Azure AD token provider

* test: Add test for Azure AD token provider in router

* fix: fix linting error

---------

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2025-03-09 18:59:37 -07:00 committed by GitHub
parent e00d4fb18c
commit 65ef65d360
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 54 additions and 31 deletions

View file

@ -195,7 +195,8 @@ class InitalizeOpenAISDKClient:
organization = get_secret_str(organization_env_name) organization = get_secret_str(organization_env_name)
litellm_params["organization"] = organization litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"): # If we have api_key, then we have higher priority
if not api_key and litellm_params.get("tenant_id"):
verbose_router_logger.debug( verbose_router_logger.debug(
"Using Azure AD Token Provider for Azure Auth" "Using Azure AD Token Provider for Azure Auth"
) )
@ -232,7 +233,7 @@ class InitalizeOpenAISDKClient:
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
elif ( elif (
azure_ad_token_provider is None not api_key and azure_ad_token_provider is None
and litellm.enable_azure_ad_token_refresh is True and litellm.enable_azure_ad_token_refresh is True
): ):
try: try:

View file

@ -1,7 +1,5 @@
import os import os
from typing import Callable, Union from typing import Callable
from litellm._logging import verbose_logger
def get_azure_ad_token_provider() -> Callable[[], str]: def get_azure_ad_token_provider() -> Callable[[], str]:
@ -16,31 +14,23 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
Returns: Returns:
Callable that returns a temporary authentication token. Callable that returns a temporary authentication token.
""" """
from azure.identity import ( import azure.identity as identity
ClientSecretCredential, from azure.identity import get_bearer_token_provider
DefaultAzureCredential,
get_bearer_token_provider,
)
try: azure_scope = os.environ.get(
credential: Union[ClientSecretCredential, DefaultAzureCredential] = ( "AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
ClientSecretCredential( )
cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")
cred_cls = getattr(identity, cred)
# ClientSecretCredential, DefaultAzureCredential, AzureCliCredential
if cred == "ClientSecretCredential":
credential = cred_cls(
client_id=os.environ["AZURE_CLIENT_ID"], client_id=os.environ["AZURE_CLIENT_ID"],
client_secret=os.environ["AZURE_CLIENT_SECRET"], client_secret=os.environ["AZURE_CLIENT_SECRET"],
tenant_id=os.environ["AZURE_TENANT_ID"], tenant_id=os.environ["AZURE_TENANT_ID"],
) )
) else:
except KeyError as e: credential = cred_cls()
verbose_logger.exception(
"Missing environment variable required by Azure AD workflow. "
"DefaultAzureCredential will be used"
" {}".format(str(e))
)
credential = DefaultAzureCredential()
except Exception:
raise
return get_bearer_token_provider( return get_bearer_token_provider(credential, azure_scope)
credential,
"https://cognitiveservices.azure.com/.default",
)

View file

@ -219,6 +219,38 @@ def test_router_azure_ai_client_init():
assert not isinstance(_client, AsyncAzureOpenAI) assert not isinstance(_client, AsyncAzureOpenAI)
def test_router_azure_ad_token_provider():
_deployment = {
"model_name": "gpt-4o_2024-05-13",
"litellm_params": {
"model": "azure/gpt-4o_2024-05-13",
"api_base": "my-fake-route",
"api_version": "2024-08-01-preview",
},
"model_info": {"id": "1234"},
}
for azure_cred in ["DefaultAzureCredential", "AzureCliCredential"]:
os.environ["AZURE_CREDENTIAL"] = azure_cred
litellm.enable_azure_ad_token_refresh = True
router = Router(model_list=[_deployment])
_client = router._get_client(
deployment=_deployment,
client_type="async",
kwargs={"stream": False},
)
print(_client)
import azure.identity as identity
from openai import AsyncAzureOpenAI, AsyncOpenAI
assert isinstance(_client, AsyncOpenAI)
assert isinstance(_client, AsyncAzureOpenAI)
assert _client._azure_ad_token_provider is not None
assert isinstance(_client._azure_ad_token_provider.__closure__, tuple)
assert isinstance(_client._azure_ad_token_provider.__closure__[0].cell_contents._credential,
getattr(identity, os.environ["AZURE_CREDENTIAL"]))
def test_router_sensitive_keys(): def test_router_sensitive_keys():
try: try:
router = Router( router = Router(