mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat: prioritize api_key over tenant_id for more Azure AD token provi… (#8701)
* feat: prioritize api_key over tenant_id for more Azure AD token provider (#8318) * fix: prioritize api_key over tenant_id for Azure AD token provider * test: Add test for Azure AD token provider in router * fix: fix linting error --------- Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
parent
e00d4fb18c
commit
65ef65d360
3 changed files with 54 additions and 31 deletions
|
@ -195,7 +195,8 @@ class InitalizeOpenAISDKClient:
|
||||||
organization = get_secret_str(organization_env_name)
|
organization = get_secret_str(organization_env_name)
|
||||||
litellm_params["organization"] = organization
|
litellm_params["organization"] = organization
|
||||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||||
if litellm_params.get("tenant_id"):
|
# If we have api_key, then we have higher priority
|
||||||
|
if not api_key and litellm_params.get("tenant_id"):
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
"Using Azure AD Token Provider for Azure Auth"
|
"Using Azure AD Token Provider for Azure Auth"
|
||||||
)
|
)
|
||||||
|
@ -232,7 +233,7 @@ class InitalizeOpenAISDKClient:
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
elif (
|
elif (
|
||||||
azure_ad_token_provider is None
|
not api_key and azure_ad_token_provider is None
|
||||||
and litellm.enable_azure_ad_token_refresh is True
|
and litellm.enable_azure_ad_token_refresh is True
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Union
|
from typing import Callable
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_azure_ad_token_provider() -> Callable[[], str]:
|
def get_azure_ad_token_provider() -> Callable[[], str]:
|
||||||
|
@ -16,31 +14,23 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
|
||||||
Returns:
|
Returns:
|
||||||
Callable that returns a temporary authentication token.
|
Callable that returns a temporary authentication token.
|
||||||
"""
|
"""
|
||||||
from azure.identity import (
|
import azure.identity as identity
|
||||||
ClientSecretCredential,
|
from azure.identity import get_bearer_token_provider
|
||||||
DefaultAzureCredential,
|
|
||||||
get_bearer_token_provider,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
azure_scope = os.environ.get(
|
||||||
credential: Union[ClientSecretCredential, DefaultAzureCredential] = (
|
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
|
||||||
ClientSecretCredential(
|
|
||||||
client_id=os.environ["AZURE_CLIENT_ID"],
|
|
||||||
client_secret=os.environ["AZURE_CLIENT_SECRET"],
|
|
||||||
tenant_id=os.environ["AZURE_TENANT_ID"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except KeyError as e:
|
|
||||||
verbose_logger.exception(
|
|
||||||
"Missing environment variable required by Azure AD workflow. "
|
|
||||||
"DefaultAzureCredential will be used"
|
|
||||||
" {}".format(str(e))
|
|
||||||
)
|
|
||||||
credential = DefaultAzureCredential()
|
|
||||||
except Exception:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return get_bearer_token_provider(
|
|
||||||
credential,
|
|
||||||
"https://cognitiveservices.azure.com/.default",
|
|
||||||
)
|
)
|
||||||
|
cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")
|
||||||
|
|
||||||
|
cred_cls = getattr(identity, cred)
|
||||||
|
# ClientSecretCredential, DefaultAzureCredential, AzureCliCredential
|
||||||
|
if cred == "ClientSecretCredential":
|
||||||
|
credential = cred_cls(
|
||||||
|
client_id=os.environ["AZURE_CLIENT_ID"],
|
||||||
|
client_secret=os.environ["AZURE_CLIENT_SECRET"],
|
||||||
|
tenant_id=os.environ["AZURE_TENANT_ID"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
credential = cred_cls()
|
||||||
|
|
||||||
|
return get_bearer_token_provider(credential, azure_scope)
|
||||||
|
|
|
@ -219,6 +219,38 @@ def test_router_azure_ai_client_init():
|
||||||
assert not isinstance(_client, AsyncAzureOpenAI)
|
assert not isinstance(_client, AsyncAzureOpenAI)
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_azure_ad_token_provider():
|
||||||
|
_deployment = {
|
||||||
|
"model_name": "gpt-4o_2024-05-13",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/gpt-4o_2024-05-13",
|
||||||
|
"api_base": "my-fake-route",
|
||||||
|
"api_version": "2024-08-01-preview",
|
||||||
|
},
|
||||||
|
"model_info": {"id": "1234"},
|
||||||
|
}
|
||||||
|
for azure_cred in ["DefaultAzureCredential", "AzureCliCredential"]:
|
||||||
|
os.environ["AZURE_CREDENTIAL"] = azure_cred
|
||||||
|
litellm.enable_azure_ad_token_refresh = True
|
||||||
|
router = Router(model_list=[_deployment])
|
||||||
|
|
||||||
|
_client = router._get_client(
|
||||||
|
deployment=_deployment,
|
||||||
|
client_type="async",
|
||||||
|
kwargs={"stream": False},
|
||||||
|
)
|
||||||
|
print(_client)
|
||||||
|
import azure.identity as identity
|
||||||
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||||
|
|
||||||
|
assert isinstance(_client, AsyncOpenAI)
|
||||||
|
assert isinstance(_client, AsyncAzureOpenAI)
|
||||||
|
assert _client._azure_ad_token_provider is not None
|
||||||
|
assert isinstance(_client._azure_ad_token_provider.__closure__, tuple)
|
||||||
|
assert isinstance(_client._azure_ad_token_provider.__closure__[0].cell_contents._credential,
|
||||||
|
getattr(identity, os.environ["AZURE_CREDENTIAL"]))
|
||||||
|
|
||||||
|
|
||||||
def test_router_sensitive_keys():
|
def test_router_sensitive_keys():
|
||||||
try:
|
try:
|
||||||
router = Router(
|
router = Router(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue