mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
feat: prioritize api_key over tenant_id for more Azure AD token provi… (#8701)
* feat: prioritize api_key over tenant_id for more Azure AD token provider (#8318) * fix: prioritize api_key over tenant_id for Azure AD token provider * test: Add test for Azure AD token provider in router * fix: fix linting error --------- Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
parent
e00d4fb18c
commit
65ef65d360
3 changed files with 54 additions and 31 deletions
|
@ -195,7 +195,8 @@ class InitalizeOpenAISDKClient:
|
|||
organization = get_secret_str(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
if litellm_params.get("tenant_id"):
|
||||
# If we have api_key, then we have higher priority
|
||||
if not api_key and litellm_params.get("tenant_id"):
|
||||
verbose_router_logger.debug(
|
||||
"Using Azure AD Token Provider for Azure Auth"
|
||||
)
|
||||
|
@ -232,7 +233,7 @@ class InitalizeOpenAISDKClient:
|
|||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
elif (
|
||||
azure_ad_token_provider is None
|
||||
not api_key and azure_ad_token_provider is None
|
||||
and litellm.enable_azure_ad_token_refresh is True
|
||||
):
|
||||
try:
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import os
|
||||
from typing import Callable, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def get_azure_ad_token_provider() -> Callable[[], str]:
|
||||
|
@ -16,31 +14,23 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
|
|||
Returns:
|
||||
Callable that returns a temporary authentication token.
|
||||
"""
|
||||
from azure.identity import (
|
||||
ClientSecretCredential,
|
||||
DefaultAzureCredential,
|
||||
get_bearer_token_provider,
|
||||
)
|
||||
import azure.identity as identity
|
||||
from azure.identity import get_bearer_token_provider
|
||||
|
||||
try:
|
||||
credential: Union[ClientSecretCredential, DefaultAzureCredential] = (
|
||||
ClientSecretCredential(
|
||||
client_id=os.environ["AZURE_CLIENT_ID"],
|
||||
client_secret=os.environ["AZURE_CLIENT_SECRET"],
|
||||
tenant_id=os.environ["AZURE_TENANT_ID"],
|
||||
)
|
||||
)
|
||||
except KeyError as e:
|
||||
verbose_logger.exception(
|
||||
"Missing environment variable required by Azure AD workflow. "
|
||||
"DefaultAzureCredential will be used"
|
||||
" {}".format(str(e))
|
||||
)
|
||||
credential = DefaultAzureCredential()
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
return get_bearer_token_provider(
|
||||
credential,
|
||||
"https://cognitiveservices.azure.com/.default",
|
||||
azure_scope = os.environ.get(
|
||||
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")
|
||||
|
||||
cred_cls = getattr(identity, cred)
|
||||
# ClientSecretCredential, DefaultAzureCredential, AzureCliCredential
|
||||
if cred == "ClientSecretCredential":
|
||||
credential = cred_cls(
|
||||
client_id=os.environ["AZURE_CLIENT_ID"],
|
||||
client_secret=os.environ["AZURE_CLIENT_SECRET"],
|
||||
tenant_id=os.environ["AZURE_TENANT_ID"],
|
||||
)
|
||||
else:
|
||||
credential = cred_cls()
|
||||
|
||||
return get_bearer_token_provider(credential, azure_scope)
|
||||
|
|
|
@ -219,6 +219,38 @@ def test_router_azure_ai_client_init():
|
|||
assert not isinstance(_client, AsyncAzureOpenAI)
|
||||
|
||||
|
||||
def test_router_azure_ad_token_provider():
|
||||
_deployment = {
|
||||
"model_name": "gpt-4o_2024-05-13",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-4o_2024-05-13",
|
||||
"api_base": "my-fake-route",
|
||||
"api_version": "2024-08-01-preview",
|
||||
},
|
||||
"model_info": {"id": "1234"},
|
||||
}
|
||||
for azure_cred in ["DefaultAzureCredential", "AzureCliCredential"]:
|
||||
os.environ["AZURE_CREDENTIAL"] = azure_cred
|
||||
litellm.enable_azure_ad_token_refresh = True
|
||||
router = Router(model_list=[_deployment])
|
||||
|
||||
_client = router._get_client(
|
||||
deployment=_deployment,
|
||||
client_type="async",
|
||||
kwargs={"stream": False},
|
||||
)
|
||||
print(_client)
|
||||
import azure.identity as identity
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
assert isinstance(_client, AsyncOpenAI)
|
||||
assert isinstance(_client, AsyncAzureOpenAI)
|
||||
assert _client._azure_ad_token_provider is not None
|
||||
assert isinstance(_client._azure_ad_token_provider.__closure__, tuple)
|
||||
assert isinstance(_client._azure_ad_token_provider.__closure__[0].cell_contents._credential,
|
||||
getattr(identity, os.environ["AZURE_CREDENTIAL"]))
|
||||
|
||||
|
||||
def test_router_sensitive_keys():
|
||||
try:
|
||||
router = Router(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue