forked from phoenix/litellm-mirror
fix(azure.py): handle asyncio.CancelledError
This commit is contained in:
parent
7b414a73a7
commit
aa6b56f057
1 changed file with 64 additions and 40 deletions
|
@@ -1,42 +1,56 @@
|
||||||
from typing import Optional, Union, Any, Literal, Coroutine, Iterable
|
import asyncio
|
||||||
from typing_extensions import overload
|
import json
|
||||||
import types, requests
|
|
||||||
from .base import BaseLLM
|
|
||||||
from litellm.utils import (
|
|
||||||
ModelResponse,
|
|
||||||
Choices,
|
|
||||||
Message,
|
|
||||||
CustomStreamWrapper,
|
|
||||||
convert_to_model_response_object,
|
|
||||||
TranscriptionResponse,
|
|
||||||
get_secret,
|
|
||||||
UnsupportedParamsError,
|
|
||||||
)
|
|
||||||
from typing import Callable, Optional, BinaryIO, List
|
|
||||||
from litellm import OpenAIConfig
|
|
||||||
import litellm, json
|
|
||||||
import httpx # type: ignore
|
|
||||||
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
|
|
||||||
from openai import AzureOpenAI, AsyncAzureOpenAI
|
|
||||||
import uuid
|
|
||||||
import os
|
import os
|
||||||
|
import types
|
||||||
|
import uuid
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
BinaryIO,
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import httpx # type: ignore
|
||||||
|
import requests
|
||||||
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||||
|
from typing_extensions import overload
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import OpenAIConfig
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.utils import (
|
||||||
|
Choices,
|
||||||
|
CustomStreamWrapper,
|
||||||
|
Message,
|
||||||
|
ModelResponse,
|
||||||
|
TranscriptionResponse,
|
||||||
|
UnsupportedParamsError,
|
||||||
|
convert_to_model_response_object,
|
||||||
|
get_secret,
|
||||||
|
)
|
||||||
|
|
||||||
from ..types.llms.openai import (
|
from ..types.llms.openai import (
|
||||||
AsyncCursorPage,
|
|
||||||
AssistantToolParam,
|
|
||||||
SyncCursorPage,
|
|
||||||
Assistant,
|
Assistant,
|
||||||
MessageData,
|
|
||||||
OpenAIMessage,
|
|
||||||
OpenAICreateThreadParamsMessage,
|
|
||||||
Thread,
|
|
||||||
AssistantToolParam,
|
|
||||||
Run,
|
|
||||||
AssistantEventHandler,
|
AssistantEventHandler,
|
||||||
|
AssistantStreamManager,
|
||||||
|
AssistantToolParam,
|
||||||
AsyncAssistantEventHandler,
|
AsyncAssistantEventHandler,
|
||||||
AsyncAssistantStreamManager,
|
AsyncAssistantStreamManager,
|
||||||
AssistantStreamManager,
|
AsyncCursorPage,
|
||||||
|
MessageData,
|
||||||
|
OpenAICreateThreadParamsMessage,
|
||||||
|
OpenAIMessage,
|
||||||
|
Run,
|
||||||
|
SyncCursorPage,
|
||||||
|
Thread,
|
||||||
)
|
)
|
||||||
from litellm.caching import DualCache
|
from .base import BaseLLM
|
||||||
|
from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport
|
||||||
|
|
||||||
azure_ad_cache = DualCache()
|
azure_ad_cache = DualCache()
|
||||||
|
|
||||||
|
@@ -313,7 +327,9 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||||
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||||
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
||||||
azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com")
|
azure_authority_host = os.getenv(
|
||||||
|
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
|
||||||
|
)
|
||||||
|
|
||||||
if azure_client_id is None or azure_tenant_id is None:
|
if azure_client_id is None or azure_tenant_id is None:
|
||||||
raise AzureOpenAIError(
|
raise AzureOpenAIError(
|
||||||
|
@@ -329,12 +345,14 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
message="OIDC token could not be retrieved from secret manager.",
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
)
|
)
|
||||||
|
|
||||||
azure_ad_token_cache_key = json.dumps({
|
azure_ad_token_cache_key = json.dumps(
|
||||||
|
{
|
||||||
"azure_client_id": azure_client_id,
|
"azure_client_id": azure_client_id,
|
||||||
"azure_tenant_id": azure_tenant_id,
|
"azure_tenant_id": azure_tenant_id,
|
||||||
"azure_authority_host": azure_authority_host,
|
"azure_authority_host": azure_authority_host,
|
||||||
"oidc_token": oidc_token,
|
"oidc_token": oidc_token,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
|
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
|
||||||
if azure_ad_token_access_token is not None:
|
if azure_ad_token_access_token is not None:
|
||||||
|
@@ -371,7 +389,11 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
status_code=422, message="Azure AD Token expires_in not returned"
|
status_code=422, message="Azure AD Token expires_in not returned"
|
||||||
)
|
)
|
||||||
|
|
||||||
azure_ad_cache.set_cache(key=azure_ad_token_cache_key, value=azure_ad_token_access_token, ttl=azure_ad_token_expires_in)
|
azure_ad_cache.set_cache(
|
||||||
|
key=azure_ad_token_cache_key,
|
||||||
|
value=azure_ad_token_access_token,
|
||||||
|
ttl=azure_ad_token_expires_in,
|
||||||
|
)
|
||||||
|
|
||||||
return azure_ad_token_access_token
|
return azure_ad_token_access_token
|
||||||
|
|
||||||
|
@@ -645,6 +667,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
except AzureOpenAIError as e:
|
except AzureOpenAIError as e:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise e
|
raise e
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
raise AzureOpenAIError(status_code=500, message=str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if hasattr(e, "status_code"):
|
if hasattr(e, "status_code"):
|
||||||
raise e
|
raise e
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue