Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:

This commit is contained in:
Ishaan Jaff 2025-03-18 14:23:14 -07:00
parent 7384d45ef0
commit 55ea2370ba
2 changed files with 16 additions and 11 deletions

View file

@ -1,5 +1,5 @@
import uuid import uuid
from typing import Any, Optional from typing import Any, Coroutine, Optional, Union
from openai import AsyncAzureOpenAI, AzureOpenAI from openai import AsyncAzureOpenAI, AzureOpenAI
from pydantic import BaseModel from pydantic import BaseModel
@ -33,11 +33,11 @@ class AzureAudioTranscription(AzureChatCompletion):
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
atranscription: bool = False, atranscription: bool = False,
litellm_params: Optional[dict] = None, litellm_params: Optional[dict] = None,
) -> TranscriptionResponse: ) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
data = {"model": model, "file": audio_file, **optional_params} data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True: if atranscription is True:
return self.async_audio_transcriptions( # type: ignore return self.async_audio_transcriptions(
audio_file=audio_file, audio_file=audio_file,
data=data, data=data,
model_response=model_response, model_response=model_response,
@ -47,6 +47,8 @@ class AzureAudioTranscription(AzureChatCompletion):
client=client, client=client,
max_retries=max_retries, max_retries=max_retries,
logging_obj=logging_obj, logging_obj=logging_obj,
model=model,
litellm_params=litellm_params,
) )
azure_client = self.get_azure_openai_client( azure_client = self.get_azure_openai_client(
@ -110,7 +112,7 @@ class AzureAudioTranscription(AzureChatCompletion):
client=None, client=None,
max_retries=None, max_retries=None,
litellm_params: Optional[dict] = None, litellm_params: Optional[dict] = None,
): ) -> TranscriptionResponse:
response = None response = None
try: try:
async_azure_client = self.get_azure_openai_client( async_azure_client = self.get_azure_openai_client(
@ -179,7 +181,12 @@ class AzureAudioTranscription(AzureChatCompletion):
model_response_object=model_response, model_response_object=model_response,
hidden_params=hidden_params, hidden_params=hidden_params,
response_type="audio_transcription", response_type="audio_transcription",
) # type: ignore )
if not isinstance(response, TranscriptionResponse):
raise AzureOpenAIError(
status_code=500,
message="response is not an instance of TranscriptionResponse",
)
return response return response
except Exception as e: except Exception as e:
## LOGGING ## LOGGING

View file

@ -308,12 +308,10 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
# Get appropriate input for this call type # Get appropriate input for this call type
input_kwarg = test_inputs.get(call_type.value, {}) input_kwarg = test_inputs.get(call_type.value, {})
patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
if call_type == CallTypes.atranscription:
patch_target = ( patch_target = (
"litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client" "litellm.llms.azure.common_utils.BaseAzureLLM.initialize_azure_sdk_client"
) )
elif call_type == CallTypes.arerank: if call_type == CallTypes.arerank:
patch_target = ( patch_target = (
"litellm.rerank_api.main.azure_rerank.initialize_azure_sdk_client" "litellm.rerank_api.main.azure_rerank.initialize_azure_sdk_client"
) )