diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index e8bcaff64b..46ab62a8d3 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -36,6 +36,9 @@ from ..types.llms.openai import ( AsyncAssistantStreamManager, AssistantStreamManager, ) +from litellm.caching import DualCache + +azure_ad_cache = DualCache() class AzureOpenAIError(Exception): @@ -326,6 +329,17 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str): message="OIDC token could not be retrieved from secret manager.", ) + azure_ad_token_cache_key = json.dumps({ + "azure_client_id": azure_client_id, + "azure_tenant_id": azure_tenant_id, + "azure_authority_host": azure_authority_host, + "oidc_token": oidc_token, + }) + + azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key) + if azure_ad_token_access_token is not None: + return azure_ad_token_access_token + req_token = httpx.post( f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token", data={ @@ -343,12 +357,23 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str): message=req_token.text, ) - possible_azure_ad_token = req_token.json().get("access_token", None) + azure_ad_token_json = req_token.json() + azure_ad_token_access_token = azure_ad_token_json.get("access_token", None) + azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None) - if possible_azure_ad_token is None: - raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned") + if azure_ad_token_access_token is None: + raise AzureOpenAIError( + status_code=422, message="Azure AD Token access_token not returned" + ) - return possible_azure_ad_token + if azure_ad_token_expires_in is None: + raise AzureOpenAIError( + status_code=422, message="Azure AD Token expires_in not returned" + ) + + azure_ad_cache.set_cache(key=azure_ad_token_cache_key, value=azure_ad_token_access_token, ttl=azure_ad_token_expires_in) + + return azure_ad_token_access_token class AzureChatCompletion(BaseLLM):