From a31fa5fbc80c29ef3b33ca675e9ff9b3c85b8af3 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Mon, 27 May 2024 16:33:37 +0000 Subject: [PATCH] feat(util.py/azure.py): Add OIDC support when running in Azure Kubernetes Service (AKS). --- litellm/llms/azure.py | 7 ++++--- litellm/utils.py | 8 ++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 834fcbea9..e8bcaff64 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -309,9 +309,10 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict): def get_azure_ad_token_from_oidc(azure_ad_token: str): azure_client_id = os.getenv("AZURE_CLIENT_ID", None) - azure_tenant = os.getenv("AZURE_TENANT_ID", None) + azure_tenant_id = os.getenv("AZURE_TENANT_ID", None) + azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com") - if azure_client_id is None or azure_tenant is None: + if azure_client_id is None or azure_tenant_id is None: raise AzureOpenAIError( status_code=422, message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set", @@ -326,7 +327,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str): ) req_token = httpx.post( - f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token", + f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token", data={ "client_id": azure_client_id, "grant_type": "client_credentials", diff --git a/litellm/utils.py b/litellm/utils.py index 5e85419dc..b872687b5 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10050,6 +10050,14 @@ def get_secret( return oidc_token else: raise ValueError("Github OIDC provider failed") + elif oidc_provider == "azure": + # https://azure.github.io/azure-workload-identity/docs/quick-start.html + azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE") + if azure_federated_token_file is None: + raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment") + with open(azure_federated_token_file, "r") as f: + oidc_token = f.read() + return oidc_token else: raise ValueError("Unsupported OIDC provider")