mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
* Azure Service Principal with Secret authentication workflow. (#5131) * Implement Azure Service Principal with Secret authentication workflow. * Use `ClientSecretCredential` instead of `DefaultAzureCredential`. * Move imports into the function. * Add type hint for `azure_ad_token_provider`. * Add unit test for router initialization and sample completion using Azure Service Principal with Secret authentication workflow. * Add unit test for router initialization with neither API key nor using Azure Service Principal with Secret authentication workflow. * fix(client_initializtion_utils.py): fix typing + overrides * test: fix linting errors * fix(client_initialization_utils.py): fix client init azure ad token logic * fix(router_client_initialization.py): add flag check for reading azure ad token from environment * test(test_streaming.py): skip end of life bedrock model * test(test_router_client_init.py): add correct flag to test --------- Co-authored-by: kzych-inpost <142029278+kzych-inpost@users.noreply.github.com>
This commit is contained in:
parent
2797b30a50
commit
02f288a8a3
4 changed files with 193 additions and 5 deletions
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
@ -9,6 +9,9 @@ import openai
|
|||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.proxy.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.utils import calculate_max_parallel_requests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -172,7 +175,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
organization_env_name = organization.replace("os.environ/", "")
|
||||
organization = litellm.get_secret(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
azure_ad_token_provider = None
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
if litellm_params.get("tenant_id"):
|
||||
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
|
||||
|
@ -197,6 +200,16 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
if azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
elif (
|
||||
azure_ad_token_provider is None
|
||||
and litellm.enable_azure_ad_token_refresh is True
|
||||
):
|
||||
try:
|
||||
azure_ad_token_provider = get_azure_ad_token_provider()
|
||||
except ValueError:
|
||||
verbose_router_logger.debug(
|
||||
"Azure AD Token Provider could not be used."
|
||||
)
|
||||
if api_version is None:
|
||||
api_version = os.getenv(
|
||||
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
|
||||
|
@ -211,6 +224,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
|
@ -236,6 +250,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
|
@ -258,6 +273,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
|
@ -283,6 +299,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
|
@ -313,6 +330,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
"azure_endpoint": api_base,
|
||||
"api_version": api_version,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
"azure_ad_token_provider": azure_ad_token_provider,
|
||||
}
|
||||
|
||||
if azure_ad_token_provider is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue