forked from phoenix/litellm-mirror
Merge pull request #3861 from Manouchehri/aks-oidc-1852
feat(util.py/azure.py): Add OIDC support when running LiteLLM on Azure + Azure Upstream caching
This commit is contained in:
commit
821d32fe17
2 changed files with 41 additions and 7 deletions
|
@@ -36,6 +36,9 @@ from ..types.llms.openai import (
|
||||||
AsyncAssistantStreamManager,
|
AsyncAssistantStreamManager,
|
||||||
AssistantStreamManager,
|
AssistantStreamManager,
|
||||||
)
|
)
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
azure_ad_cache = DualCache()
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIError(Exception):
|
class AzureOpenAIError(Exception):
|
||||||
|
@@ -309,9 +312,10 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||||
|
|
||||||
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||||
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
|
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
||||||
|
azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com")
|
||||||
|
|
||||||
if azure_client_id is None or azure_tenant is None:
|
if azure_client_id is None or azure_tenant_id is None:
|
||||||
raise AzureOpenAIError(
|
raise AzureOpenAIError(
|
||||||
status_code=422,
|
status_code=422,
|
||||||
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
||||||
|
@@ -325,8 +329,19 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
message="OIDC token could not be retrieved from secret manager.",
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
azure_ad_token_cache_key = json.dumps({
|
||||||
|
"azure_client_id": azure_client_id,
|
||||||
|
"azure_tenant_id": azure_tenant_id,
|
||||||
|
"azure_authority_host": azure_authority_host,
|
||||||
|
"oidc_token": oidc_token,
|
||||||
|
})
|
||||||
|
|
||||||
|
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
|
||||||
|
if azure_ad_token_access_token is not None:
|
||||||
|
return azure_ad_token_access_token
|
||||||
|
|
||||||
req_token = httpx.post(
|
req_token = httpx.post(
|
||||||
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
|
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
|
||||||
data={
|
data={
|
||||||
"client_id": azure_client_id,
|
"client_id": azure_client_id,
|
||||||
"grant_type": "client_credentials",
|
"grant_type": "client_credentials",
|
||||||
|
@@ -342,12 +357,23 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
message=req_token.text,
|
message=req_token.text,
|
||||||
)
|
)
|
||||||
|
|
||||||
possible_azure_ad_token = req_token.json().get("access_token", None)
|
azure_ad_token_json = req_token.json()
|
||||||
|
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
|
||||||
|
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
|
||||||
|
|
||||||
if possible_azure_ad_token is None:
|
if azure_ad_token_access_token is None:
|
||||||
raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")
|
raise AzureOpenAIError(
|
||||||
|
status_code=422, message="Azure AD Token access_token not returned"
|
||||||
|
)
|
||||||
|
|
||||||
return possible_azure_ad_token
|
if azure_ad_token_expires_in is None:
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=422, message="Azure AD Token expires_in not returned"
|
||||||
|
)
|
||||||
|
|
||||||
|
azure_ad_cache.set_cache(key=azure_ad_token_cache_key, value=azure_ad_token_access_token, ttl=azure_ad_token_expires_in)
|
||||||
|
|
||||||
|
return azure_ad_token_access_token
|
||||||
|
|
||||||
|
|
||||||
class AzureChatCompletion(BaseLLM):
|
class AzureChatCompletion(BaseLLM):
|
||||||
|
|
|
@@ -10066,6 +10066,14 @@ def get_secret(
|
||||||
return oidc_token
|
return oidc_token
|
||||||
else:
|
else:
|
||||||
raise ValueError("Github OIDC provider failed")
|
raise ValueError("Github OIDC provider failed")
|
||||||
|
elif oidc_provider == "azure":
|
||||||
|
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
|
||||||
|
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
|
||||||
|
if azure_federated_token_file is None:
|
||||||
|
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
|
||||||
|
with open(azure_federated_token_file, "r") as f:
|
||||||
|
oidc_token = f.read()
|
||||||
|
return oidc_token
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported OIDC provider")
|
raise ValueError("Unsupported OIDC provider")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue