def get_azure_ad_token_from_oidc(azure_ad_token: str) -> str:
    """Exchange an OIDC token for an Azure AD access token.

    The value of ``azure_ad_token`` is passed to ``get_secret`` to obtain the
    OIDC token, which is then sent as a JWT-bearer client assertion to the
    tenant's OAuth 2.0 token endpoint using the client-credentials grant and
    the ``https://cognitiveservices.azure.com/.default`` scope.

    Args:
        azure_ad_token: Reference handed to ``get_secret`` to look up the
            OIDC token (callers in this file pass values starting with
            ``"oidc/"``).

    Returns:
        The ``access_token`` field of the token endpoint's JSON response.

    Raises:
        AzureOpenAIError: With status 422 if ``AZURE_CLIENT_ID`` or
            ``AZURE_TENANT_ID`` is not set in the environment, with status
            401 if ``get_secret`` returns ``None``, with the endpoint's own
            status code if the response is not 200, and with status 422 if
            the response carries no ``access_token``.
    """
    azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
    azure_tenant = os.getenv("AZURE_TENANT_ID", None)

    if azure_client_id is None or azure_tenant is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
        )

    oidc_token = get_secret(azure_ad_token)

    if oidc_token is None:
        raise AzureOpenAIError(
            status_code=401,
            message="OIDC token could not be retrieved from secret manager.",
        )

    # The token request carries a form-encoded body, so it has to be a POST:
    # `httpx.get` does not accept a `data` argument and would raise TypeError
    # before any request is sent.
    req_token = httpx.post(
        f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
        data={
            "client_id": azure_client_id,
            "grant_type": "client_credentials",
            "scope": "https://cognitiveservices.azure.com/.default",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": oidc_token,
        },
    )

    if req_token.status_code != 200:
        # Surface the endpoint's own status and body so the caller can see
        # why the exchange was rejected.
        raise AzureOpenAIError(
            status_code=req_token.status_code,
            message=req_token.text,
        )

    possible_azure_ad_token = req_token.json().get("access_token", None)

    if possible_azure_ad_token is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token not returned"
        )

    return possible_azure_ad_token
__init__(self) -> None: super().__init__() @@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: headers["api-key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) headers["Authorization"] = f"Bearer {azure_ad_token}" return headers @@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + azure_client_params["azure_ad_token"] = azure_ad_token if acompletion is True: @@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if client is None: azure_client = AzureOpenAI(**azure_client_params) @@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token # setting Azure client @@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if client is None: azure_client = AzureOpenAI(**azure_client_params) @@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + 
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if client is None: azure_client = AsyncAzureOpenAI(**azure_client_params) @@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token ## LOGGING @@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if aimg_generation == True: @@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if max_retries is not None: