feat(util.py/azure.py): Add OIDC support when running in Azure Kubernetes Service (AKS).

This commit is contained in:
David Manouchehri 2024-05-27 16:33:37 +00:00
parent 857df1d6af
commit a31fa5fbc8
No known key found for this signature in database
2 changed files with 12 additions and 3 deletions

View file

@ -309,9 +309,10 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
def get_azure_ad_token_from_oidc(azure_ad_token: str): def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None) azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant = os.getenv("AZURE_TENANT_ID", None) azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com")
if azure_client_id is None or azure_tenant is None: if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError( raise AzureOpenAIError(
status_code=422, status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set", message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
@ -326,7 +327,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
) )
req_token = httpx.post( req_token = httpx.post(
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token", f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
data={ data={
"client_id": azure_client_id, "client_id": azure_client_id,
"grant_type": "client_credentials", "grant_type": "client_credentials",

View file

@ -10050,6 +10050,14 @@ def get_secret(
return oidc_token return oidc_token
else: else:
raise ValueError("Github OIDC provider failed") raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
else: else:
raise ValueError("Unsupported OIDC provider") raise ValueError("Unsupported OIDC provider")