diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 5294bd7141..0fc7370ebd 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -8,7 +8,6 @@ import httpx # type: ignore from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI import litellm -from litellm.caching.caching import DualCache from litellm.constants import DEFAULT_MAX_RETRIES from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.custom_httpx.http_handler import ( @@ -25,15 +24,16 @@ from litellm.types.utils import ( from litellm.utils import ( CustomStreamWrapper, convert_to_model_response_object, - get_secret, modify_url, ) from ...types.llms.openai import HttpxBinaryResponseContent from ..base import BaseLLM -from .common_utils import AzureOpenAIError, process_azure_headers - -azure_ad_cache = DualCache() +from .common_utils import ( + AzureOpenAIError, + get_azure_ad_token_from_oidc, + process_azure_headers, +) class AzureOpenAIAssistantsAPIConfig: @@ -110,81 +110,6 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict): return azure_client_params -def get_azure_ad_token_from_oidc(azure_ad_token: str): - azure_client_id = os.getenv("AZURE_CLIENT_ID", None) - azure_tenant_id = os.getenv("AZURE_TENANT_ID", None) - azure_authority_host = os.getenv( - "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com" - ) - - if azure_client_id is None or azure_tenant_id is None: - raise AzureOpenAIError( - status_code=422, - message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set", - ) - - oidc_token = get_secret(azure_ad_token) - - if oidc_token is None: - raise AzureOpenAIError( - status_code=401, - message="OIDC token could not be retrieved from secret manager.", - ) - - azure_ad_token_cache_key = json.dumps( - { - "azure_client_id": azure_client_id, - "azure_tenant_id": azure_tenant_id, - "azure_authority_host": azure_authority_host, - "oidc_token": oidc_token, - } - ) - - 
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key) - if azure_ad_token_access_token is not None: - return azure_ad_token_access_token - - client = litellm.module_level_client - req_token = client.post( - f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token", - data={ - "client_id": azure_client_id, - "grant_type": "client_credentials", - "scope": "https://cognitiveservices.azure.com/.default", - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_assertion": oidc_token, - }, - ) - - if req_token.status_code != 200: - raise AzureOpenAIError( - status_code=req_token.status_code, - message=req_token.text, - ) - - azure_ad_token_json = req_token.json() - azure_ad_token_access_token = azure_ad_token_json.get("access_token", None) - azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None) - - if azure_ad_token_access_token is None: - raise AzureOpenAIError( - status_code=422, message="Azure AD Token access_token not returned" - ) - - if azure_ad_token_expires_in is None: - raise AzureOpenAIError( - status_code=422, message="Azure AD Token expires_in not returned" - ) - - azure_ad_cache.set_cache( - key=azure_ad_token_cache_key, - value=azure_ad_token_access_token, - ttl=azure_ad_token_expires_in, - ) - - return azure_ad_token_access_token - - def _check_dynamic_azure_params( azure_client_params: dict, azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]], diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index 2a96f5c39c..b2c61005ba 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -1,3 +1,5 @@ +import json +import os from typing import Callable, Optional, Union import httpx @@ -5,9 +7,16 @@ from openai import AsyncAzureOpenAI, AzureOpenAI import litellm from litellm._logging import verbose_logger +from litellm.caching.caching import DualCache from litellm.llms.base_llm.chat.transformation import 
BaseLLMException +from litellm.secret_managers.get_azure_ad_token_provider import ( + get_azure_ad_token_provider, +) +from litellm.secret_managers.get_secret import get_secret from litellm.secret_managers.main import get_secret_str +azure_ad_cache = DualCache() + class AzureOpenAIError(BaseLLMException): def __init__( @@ -178,3 +187,147 @@ def get_azure_ad_token_from_username_password( verbose_logger.debug("token_provider %s", token_provider) return token_provider + + +def get_azure_ad_token_from_oidc(azure_ad_token: str): + azure_client_id = os.getenv("AZURE_CLIENT_ID", None) + azure_tenant_id = os.getenv("AZURE_TENANT_ID", None) + azure_authority_host = os.getenv( + "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com" + ) + + if azure_client_id is None or azure_tenant_id is None: + raise AzureOpenAIError( + status_code=422, + message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set", + ) + + oidc_token = get_secret(azure_ad_token) + + if oidc_token is None: + raise AzureOpenAIError( + status_code=401, + message="OIDC token could not be retrieved from secret manager.", + ) + + azure_ad_token_cache_key = json.dumps( + { + "azure_client_id": azure_client_id, + "azure_tenant_id": azure_tenant_id, + "azure_authority_host": azure_authority_host, + "oidc_token": oidc_token, + } + ) + + azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key) + if azure_ad_token_access_token is not None: + return azure_ad_token_access_token + + client = litellm.module_level_client + req_token = client.post( + f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token", + data={ + "client_id": azure_client_id, + "grant_type": "client_credentials", + "scope": "https://cognitiveservices.azure.com/.default", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": oidc_token, + }, + ) + + if req_token.status_code != 200: + raise AzureOpenAIError( + status_code=req_token.status_code, + message=req_token.text, 
+ ) + + azure_ad_token_json = req_token.json() + azure_ad_token_access_token = azure_ad_token_json.get("access_token", None) + azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None) + + if azure_ad_token_access_token is None: + raise AzureOpenAIError( + status_code=422, message="Azure AD Token access_token not returned" + ) + + if azure_ad_token_expires_in is None: + raise AzureOpenAIError( + status_code=422, message="Azure AD Token expires_in not returned" + ) + + azure_ad_cache.set_cache( + key=azure_ad_token_cache_key, + value=azure_ad_token_access_token, + ttl=azure_ad_token_expires_in, + ) + + return azure_ad_token_access_token + + +def initialize_azure_sdk_client( + litellm_params: dict, + api_key: Optional[str], + api_base: Optional[str], + model_name: str, + api_version: Optional[str], +) -> dict: + azure_ad_token_provider: Optional[Callable[[], str]] = None + # If we have api_key, then we have higher priority + azure_ad_token = litellm_params.get("azure_ad_token") + tenant_id = litellm_params.get("tenant_id") + client_id = litellm_params.get("client_id") + client_secret = litellm_params.get("client_secret") + azure_username = litellm_params.get("azure_username") + azure_password = litellm_params.get("azure_password") + if not api_key and tenant_id and client_id and client_secret: + verbose_logger.debug("Using Azure AD Token Provider for Azure Auth") + azure_ad_token_provider = get_azure_ad_token_from_entrata_id( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + ) + if azure_username and azure_password and client_id: + azure_ad_token_provider = get_azure_ad_token_from_username_password( + azure_username=azure_username, + azure_password=azure_password, + client_id=client_id, + ) + + if azure_ad_token is not None and azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + elif ( + not api_key + and azure_ad_token_provider is None + and 
litellm.enable_azure_ad_token_refresh is True + ): + try: + azure_ad_token_provider = get_azure_ad_token_provider() + except ValueError: + verbose_logger.debug("Azure AD Token Provider could not be used.") + if api_version is None: + api_version = os.getenv("AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION) + + _api_key = api_key + if _api_key is not None and isinstance(_api_key, str): + # only show first 8 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_logger.debug( + f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" + ) + azure_client_params = { + "api_key": api_key, + "azure_endpoint": api_base, + "api_version": api_version, + "azure_ad_token": azure_ad_token, + "azure_ad_token_provider": azure_ad_token_provider, + } + + if azure_ad_token_provider is not None: + azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider + from litellm.llms.azure.azure import select_azure_base_url_or_endpoint + + # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client + # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client + azure_client_params = select_azure_base_url_or_endpoint(azure_client_params) + + return azure_client_params diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 39633d448f..80e0df5202 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -8,14 +8,6 @@ import openai import litellm from litellm import get_secret, get_secret_str from litellm._logging import verbose_router_logger -from litellm.llms.azure.azure import get_azure_ad_token_from_oidc -from litellm.llms.azure.common_utils import ( - get_azure_ad_token_from_entrata_id, - get_azure_ad_token_from_username_password, -) -from litellm.secret_managers.get_azure_ad_token_provider import ( - 
get_azure_ad_token_provider, -) from litellm.utils import calculate_max_parallel_requests if TYPE_CHECKING: @@ -294,72 +286,3 @@ class InitalizeOpenAISDKClient: ttl=client_ttl, local_only=True, ) # cache for 1 hr - - -def initialize_azure_sdk_client( - litellm_params: dict, - api_key: Optional[str], - api_base: Optional[str], - model_name: str, - api_version: Optional[str], -): - azure_ad_token_provider: Optional[Callable[[], str]] = None - # If we have api_key, then we have higher priority - azure_ad_token = litellm_params.get("azure_ad_token") - tenant_id = litellm_params.get("tenant_id") - client_id = litellm_params.get("client_id") - client_secret = litellm_params.get("client_secret") - azure_username = litellm_params.get("azure_username") - azure_password = litellm_params.get("azure_password") - if not api_key and tenant_id and client_id and client_secret: - verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth") - azure_ad_token_provider = get_azure_ad_token_from_entrata_id( - tenant_id=tenant_id, - client_id=client_id, - client_secret=client_secret, - ) - if azure_username and azure_password and client_id: - azure_ad_token_provider = get_azure_ad_token_from_username_password( - azure_username=azure_username, - azure_password=azure_password, - client_id=client_id, - ) - - if azure_ad_token is not None and azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - elif ( - not api_key - and azure_ad_token_provider is None - and litellm.enable_azure_ad_token_refresh is True - ): - try: - azure_ad_token_provider = get_azure_ad_token_provider() - except ValueError: - verbose_router_logger.debug("Azure AD Token Provider could not be used.") - if api_version is None: - api_version = os.getenv("AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION) - - _api_key = api_key - if _api_key is not None and isinstance(_api_key, str): - # only show first 5 chars of api_key - _api_key = _api_key[:8] + "*" * 15 - 
verbose_router_logger.debug( - f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" - ) - azure_client_params = { - "api_key": api_key, - "azure_endpoint": api_base, - "api_version": api_version, - "azure_ad_token": azure_ad_token, - "azure_ad_token_provider": azure_ad_token_provider, - } - - if azure_ad_token_provider is not None: - azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider - from litellm.llms.azure.azure import select_azure_base_url_or_endpoint - - # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client - # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client - azure_client_params = select_azure_base_url_or_endpoint(azure_client_params) - - return azure_client_params