refactor(azure/audio_transcriptions.py): support client init with common logic

This commit is contained in:
Krrish Dholakia 2025-03-11 14:24:12 -07:00
parent 2c2404dac9
commit af71e14d79
3 changed files with 9 additions and 19 deletions

View file

@ -32,29 +32,18 @@ class AzureAudioTranscription(AzureChatCompletion):
client=None, client=None,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
atranscription: bool = False, atranscription: bool = False,
litellm_params: Optional[dict] = None,
) -> TranscriptionResponse: ) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params} data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = self.initialize_azure_sdk_client(
"api_version": api_version, litellm_params=litellm_params or {},
"azure_endpoint": api_base, api_key=api_key,
"azure_deployment": model, model_name=model,
"timeout": timeout, api_version=api_version,
} api_base=api_base,
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
) )
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if atranscription is True: if atranscription is True:
return self.async_audio_transcriptions( # type: ignore return self.async_audio_transcriptions( # type: ignore

View file

@ -5066,6 +5066,7 @@ def transcription(
api_version=api_version, api_version=api_version,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
max_retries=max_retries, max_retries=max_retries,
litellm_params=litellm_params_dict,
) )
elif ( elif (
custom_llm_provider == "openai" custom_llm_provider == "openai"

View file

@ -220,7 +220,7 @@ def test_select_azure_base_url_called(setup_mocks):
CallTypes.atext_completion, CallTypes.atext_completion,
CallTypes.aembedding, CallTypes.aembedding,
# CallTypes.arerank, # CallTypes.arerank,
# CallTypes.atranscription, CallTypes.atranscription,
CallTypes.aspeech, CallTypes.aspeech,
CallTypes.aimage_generation, CallTypes.aimage_generation,
# BATCHES ENDPOINTS # BATCHES ENDPOINTS