Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:

This commit is contained in:
Ishaan Jaff 2025-03-18 14:23:14 -07:00
parent 7384d45ef0
commit 55ea2370ba
2 changed files with 16 additions and 11 deletions

View file

@ -1,5 +1,5 @@
import uuid
from typing import Any, Optional
from typing import Any, Coroutine, Optional, Union
from openai import AsyncAzureOpenAI, AzureOpenAI
from pydantic import BaseModel
@ -33,11 +33,11 @@ class AzureAudioTranscription(AzureChatCompletion):
azure_ad_token: Optional[str] = None,
atranscription: bool = False,
litellm_params: Optional[dict] = None,
) -> TranscriptionResponse:
) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True:
return self.async_audio_transcriptions( # type: ignore
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
@ -47,6 +47,8 @@ class AzureAudioTranscription(AzureChatCompletion):
client=client,
max_retries=max_retries,
logging_obj=logging_obj,
model=model,
litellm_params=litellm_params,
)
azure_client = self.get_azure_openai_client(
@ -110,7 +112,7 @@ class AzureAudioTranscription(AzureChatCompletion):
client=None,
max_retries=None,
litellm_params: Optional[dict] = None,
):
) -> TranscriptionResponse:
response = None
try:
async_azure_client = self.get_azure_openai_client(
@ -179,7 +181,12 @@ class AzureAudioTranscription(AzureChatCompletion):
model_response_object=model_response,
hidden_params=hidden_params,
response_type="audio_transcription",
) # type: ignore
)
if not isinstance(response, TranscriptionResponse):
raise AzureOpenAIError(
status_code=500,
message="response is not an instance of TranscriptionResponse",
)
return response
except Exception as e:
## LOGGING

View file

@ -308,12 +308,10 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
# Get appropriate input for this call type
input_kwarg = test_inputs.get(call_type.value, {})
patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
if call_type == CallTypes.atranscription:
patch_target = (
"litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client"
)
elif call_type == CallTypes.arerank:
patch_target = (
"litellm.llms.azure.common_utils.BaseAzureLLM.initialize_azure_sdk_client"
)
if call_type == CallTypes.arerank:
patch_target = (
"litellm.rerank_api.main.azure_rerank.initialize_azure_sdk_client"
)