forked from phoenix/litellm-mirror
fix(azure.py): handle asyncio.CancelledError
parent 7b414a73a7
commit aa6b56f057

1 changed file with 64 additions and 40 deletions
@@ -1,42 +1,56 @@
-from typing import Optional, Union, Any, Literal, Coroutine, Iterable
-from typing_extensions import overload
-import types, requests
-from .base import BaseLLM
-from litellm.utils import (
-    ModelResponse,
-    Choices,
-    Message,
-    CustomStreamWrapper,
-    convert_to_model_response_object,
-    TranscriptionResponse,
-    get_secret,
-    UnsupportedParamsError,
-)
-from typing import Callable, Optional, BinaryIO, List
-from litellm import OpenAIConfig
-import litellm, json
-import httpx  # type: ignore
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
-from openai import AzureOpenAI, AsyncAzureOpenAI
-import uuid
+import asyncio
+import json
 import os
+import types
+import uuid
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    Coroutine,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Union,
+)
+
+import httpx  # type: ignore
+import requests
+from openai import AsyncAzureOpenAI, AzureOpenAI
+from typing_extensions import overload
+
+import litellm
+from litellm import OpenAIConfig
+from litellm.caching import DualCache
+from litellm.utils import (
+    Choices,
+    CustomStreamWrapper,
+    Message,
+    ModelResponse,
+    TranscriptionResponse,
+    UnsupportedParamsError,
+    convert_to_model_response_object,
+    get_secret,
+)
+
 from ..types.llms.openai import (
-    AsyncCursorPage,
-    AssistantToolParam,
-    SyncCursorPage,
     Assistant,
-    MessageData,
-    OpenAIMessage,
-    OpenAICreateThreadParamsMessage,
-    Thread,
-    AssistantToolParam,
-    Run,
     AssistantEventHandler,
     AssistantStreamManager,
     AssistantToolParam,
     AsyncAssistantEventHandler,
     AsyncAssistantStreamManager,
-    AssistantStreamManager,
+    AsyncCursorPage,
+    MessageData,
+    OpenAICreateThreadParamsMessage,
+    OpenAIMessage,
+    Run,
+    SyncCursorPage,
+    Thread,
 )
-from litellm.caching import DualCache
+from .base import BaseLLM
+from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport
+
 azure_ad_cache = DualCache()
+
@@ -313,7 +327,9 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
 def get_azure_ad_token_from_oidc(azure_ad_token: str):
     azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
     azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
-    azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com")
+    azure_authority_host = os.getenv(
+        "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
+    )
 
     if azure_client_id is None or azure_tenant_id is None:
         raise AzureOpenAIError(
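The hunk above only re-wraps the AZURE_AUTHORITY_HOST lookup; behavior is unchanged. For reference, the OIDC flow is driven entirely by environment variables and fails fast when the required ones are missing. A minimal sketch of that configuration (values are placeholders; the variable names and the default authority host are taken from the hunk):

```python
import os

# Placeholder values; the variable names come from the hunk above.
os.environ["AZURE_CLIENT_ID"] = "<app-registration-client-id>"
os.environ["AZURE_TENANT_ID"] = "<directory-tenant-id>"

# Optional: falls back to the public-cloud login endpoint when unset.
os.environ.setdefault("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com")

# With either AZURE_CLIENT_ID or AZURE_TENANT_ID missing,
# get_azure_ad_token_from_oidc() raises AzureOpenAIError.
```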
@@ -329,12 +345,14 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
             message="OIDC token could not be retrieved from secret manager.",
         )
 
-    azure_ad_token_cache_key = json.dumps({
-        "azure_client_id": azure_client_id,
-        "azure_tenant_id": azure_tenant_id,
-        "azure_authority_host": azure_authority_host,
-        "oidc_token": oidc_token,
-    })
+    azure_ad_token_cache_key = json.dumps(
+        {
+            "azure_client_id": azure_client_id,
+            "azure_tenant_id": azure_tenant_id,
+            "azure_authority_host": azure_authority_host,
+            "oidc_token": oidc_token,
+        }
+    )
 
     azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
     if azure_ad_token_access_token is not None:
@@ -371,7 +389,11 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
             status_code=422, message="Azure AD Token expires_in not returned"
         )
 
-    azure_ad_cache.set_cache(key=azure_ad_token_cache_key, value=azure_ad_token_access_token, ttl=azure_ad_token_expires_in)
+    azure_ad_cache.set_cache(
+        key=azure_ad_token_cache_key,
+        value=azure_ad_token_access_token,
+        ttl=azure_ad_token_expires_in,
+    )
 
     return azure_ad_token_access_token
 
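The two hunks above are formatting-only, but they spell out the caching scheme: the exchanged Azure AD access token is stored in the module-level `azure_ad_cache` (a `DualCache`) under a JSON key built from the client id, tenant id, authority host, and OIDC token, with the token's `expires_in` used as the TTL. A self-contained sketch of the same pattern, assuming a hypothetical `exchange_oidc_token` helper in place of the real token-endpoint call:

```python
import json

from litellm.caching import DualCache

azure_ad_cache = DualCache()


def exchange_oidc_token(client_id, tenant_id, authority_host, oidc_token):
    """Hypothetical stand-in for the real Azure AD token exchange.

    Returns (access_token, expires_in_seconds).
    """
    raise NotImplementedError


def get_azure_ad_token_cached(client_id, tenant_id, authority_host, oidc_token):
    # The key identifies one (client, tenant, authority, OIDC token)
    # combination, mirroring the json.dumps(...) key built in the diff.
    cache_key = json.dumps(
        {
            "azure_client_id": client_id,
            "azure_tenant_id": tenant_id,
            "azure_authority_host": authority_host,
            "oidc_token": oidc_token,
        }
    )

    access_token = azure_ad_cache.get_cache(cache_key)
    if access_token is not None:
        return access_token  # cache hit: skip the network round trip

    access_token, expires_in = exchange_oidc_token(
        client_id, tenant_id, authority_host, oidc_token
    )
    # Cache the token for as long as Azure AD says it is valid.
    azure_ad_cache.set_cache(key=cache_key, value=access_token, ttl=expires_in)
    return access_token
```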
@@ -645,6 +667,8 @@ class AzureChatCompletion(BaseLLM):
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
+        except asyncio.CancelledError as e:
+            raise AzureOpenAIError(status_code=500, message=str(e))
         except Exception as e:
             if hasattr(e, "status_code"):
                 raise e
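This hunk is the actual fix. Since Python 3.8, `asyncio.CancelledError` derives from `BaseException`, so the existing `except Exception` branch never catches it and a cancelled request previously propagated unmapped; the new branch converts it into an `AzureOpenAIError` with a 500 status code so it flows through the same exception mapping as other provider errors. A standalone sketch of the pattern, using stand-ins (`ProviderError`, `call_azure`) rather than the litellm types:

```python
import asyncio


class ProviderError(Exception):
    """Stand-in for AzureOpenAIError: an error carrying an HTTP-style status."""

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(message)


async def call_azure() -> str:
    # Stand-in for the real Azure OpenAI request.
    await asyncio.sleep(10)
    return "response"


async def acompletion_like() -> str:
    try:
        return await call_azure()
    except asyncio.CancelledError as e:
        # CancelledError is a BaseException, so the `except Exception` branch
        # below never sees it; map it explicitly, as the diff does.
        raise ProviderError(status_code=500, message=str(e))
    except Exception as e:
        if hasattr(e, "status_code"):
            raise e
        raise ProviderError(status_code=500, message=str(e))
```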