fix(azure.py): handle asyncio.CancelledError

This commit is contained in:
Krrish Dholakia 2024-06-18 20:14:05 -07:00
parent 7b414a73a7
commit aa6b56f057

View file

@ -1,42 +1,56 @@
from typing import Optional, Union, Any, Literal, Coroutine, Iterable
from typing_extensions import overload
import types, requests
from .base import BaseLLM
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
get_secret,
UnsupportedParamsError,
)
from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig
import litellm, json
import httpx # type: ignore
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import asyncio
import json
import os
import types
import uuid
from typing import (
Any,
BinaryIO,
Callable,
Coroutine,
Iterable,
List,
Literal,
Optional,
Union,
)
import httpx # type: ignore
import requests
from openai import AsyncAzureOpenAI, AzureOpenAI
from typing_extensions import overload
import litellm
from litellm import OpenAIConfig
from litellm.caching import DualCache
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
TranscriptionResponse,
UnsupportedParamsError,
convert_to_model_response_object,
get_secret,
)
from ..types.llms.openai import (
AsyncCursorPage,
AssistantToolParam,
SyncCursorPage,
Assistant,
MessageData,
OpenAIMessage,
OpenAICreateThreadParamsMessage,
Thread,
AssistantToolParam,
Run,
AssistantEventHandler,
AssistantStreamManager,
AssistantToolParam,
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AssistantStreamManager,
AsyncCursorPage,
MessageData,
OpenAICreateThreadParamsMessage,
OpenAIMessage,
Run,
SyncCursorPage,
Thread,
)
from litellm.caching import DualCache
from .base import BaseLLM
from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport
azure_ad_cache = DualCache()
@ -313,7 +327,9 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com")
azure_authority_host = os.getenv(
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
)
if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError(
@ -329,12 +345,14 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
message="OIDC token could not be retrieved from secret manager.",
)
azure_ad_token_cache_key = json.dumps({
azure_ad_token_cache_key = json.dumps(
{
"azure_client_id": azure_client_id,
"azure_tenant_id": azure_tenant_id,
"azure_authority_host": azure_authority_host,
"oidc_token": oidc_token,
})
}
)
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
if azure_ad_token_access_token is not None:
@ -371,7 +389,11 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
status_code=422, message="Azure AD Token expires_in not returned"
)
azure_ad_cache.set_cache(key=azure_ad_token_cache_key, value=azure_ad_token_access_token, ttl=azure_ad_token_expires_in)
azure_ad_cache.set_cache(
key=azure_ad_token_cache_key,
value=azure_ad_token_access_token,
ttl=azure_ad_token_expires_in,
)
return azure_ad_token_access_token
@ -645,6 +667,8 @@ class AzureChatCompletion(BaseLLM):
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except asyncio.CancelledError as e:
raise AzureOpenAIError(status_code=500, message=str(e))
except Exception as e:
if hasattr(e, "status_code"):
raise e