diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 46ab62a8d..c292c3423 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -1,42 +1,56 @@ -from typing import Optional, Union, Any, Literal, Coroutine, Iterable -from typing_extensions import overload -import types, requests -from .base import BaseLLM -from litellm.utils import ( - ModelResponse, - Choices, - Message, - CustomStreamWrapper, - convert_to_model_response_object, - TranscriptionResponse, - get_secret, - UnsupportedParamsError, -) -from typing import Callable, Optional, BinaryIO, List -from litellm import OpenAIConfig -import litellm, json -import httpx # type: ignore -from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport -from openai import AzureOpenAI, AsyncAzureOpenAI -import uuid +import asyncio +import json import os +import types +import uuid +from typing import ( + Any, + BinaryIO, + Callable, + Coroutine, + Iterable, + List, + Literal, + Optional, + Union, +) + +import httpx # type: ignore +import requests +from openai import AsyncAzureOpenAI, AzureOpenAI +from typing_extensions import overload + +import litellm +from litellm import OpenAIConfig +from litellm.caching import DualCache +from litellm.utils import ( + Choices, + CustomStreamWrapper, + Message, + ModelResponse, + TranscriptionResponse, + UnsupportedParamsError, + convert_to_model_response_object, + get_secret, +) + from ..types.llms.openai import ( - AsyncCursorPage, - AssistantToolParam, - SyncCursorPage, Assistant, - MessageData, - OpenAIMessage, - OpenAICreateThreadParamsMessage, - Thread, - AssistantToolParam, - Run, AssistantEventHandler, + AssistantStreamManager, + AssistantToolParam, AsyncAssistantEventHandler, AsyncAssistantStreamManager, - AssistantStreamManager, + AsyncCursorPage, + MessageData, + OpenAICreateThreadParamsMessage, + OpenAIMessage, + Run, + SyncCursorPage, + Thread, ) -from litellm.caching import DualCache +from .base import BaseLLM +from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport azure_ad_cache = DualCache() @@ -313,7 +327,9 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict): def get_azure_ad_token_from_oidc(azure_ad_token: str): azure_client_id = os.getenv("AZURE_CLIENT_ID", None) azure_tenant_id = os.getenv("AZURE_TENANT_ID", None) - azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com") + azure_authority_host = os.getenv( + "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com" + ) if azure_client_id is None or azure_tenant_id is None: raise AzureOpenAIError( @@ -329,12 +345,14 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str): message="OIDC token could not be retrieved from secret manager.", ) - azure_ad_token_cache_key = json.dumps({ - "azure_client_id": azure_client_id, - "azure_tenant_id": azure_tenant_id, - "azure_authority_host": azure_authority_host, - "oidc_token": oidc_token, - }) + azure_ad_token_cache_key = json.dumps( + { + "azure_client_id": azure_client_id, + "azure_tenant_id": azure_tenant_id, + "azure_authority_host": azure_authority_host, + "oidc_token": oidc_token, + } + ) azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key) if azure_ad_token_access_token is not None: @@ -371,7 +389,11 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str): status_code=422, message="Azure AD Token expires_in not returned" ) - azure_ad_cache.set_cache(key=azure_ad_token_cache_key, value=azure_ad_token_access_token, ttl=azure_ad_token_expires_in) + azure_ad_cache.set_cache( + key=azure_ad_token_cache_key, + value=azure_ad_token_access_token, + ttl=azure_ad_token_expires_in, + ) return azure_ad_token_access_token @@ -645,6 +667,8 @@ class AzureChatCompletion(BaseLLM): except AzureOpenAIError as e: exception_mapping_worked = True raise e + except asyncio.CancelledError as e: + raise AzureOpenAIError(status_code=500, message=str(e)) except Exception as e: if hasattr(e, "status_code"): raise e