diff --git a/litellm/llms/azure/audio_transcriptions.py b/litellm/llms/azure/audio_transcriptions.py index 94793295ca..69d0f5285c 100644 --- a/litellm/llms/azure/audio_transcriptions.py +++ b/litellm/llms/azure/audio_transcriptions.py @@ -32,29 +32,18 @@ class AzureAudioTranscription(AzureChatCompletion): client=None, azure_ad_token: Optional[str] = None, atranscription: bool = False, + litellm_params: Optional[dict] = None, ) -> TranscriptionResponse: data = {"model": model, "file": audio_file, **optional_params} # init AzureOpenAI Client - azure_client_params = { - "api_version": api_version, - "azure_endpoint": api_base, - "azure_deployment": model, - "timeout": timeout, - } - - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params=azure_client_params + azure_client_params = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + model_name=model, + api_version=api_version, + api_base=api_base, ) - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - azure_client_params["azure_ad_token"] = azure_ad_token - - if max_retries is not None: - azure_client_params["max_retries"] = max_retries if atranscription is True: return self.async_audio_transcriptions( # type: ignore diff --git a/litellm/main.py b/litellm/main.py index b0a4268106..0d80ac4943 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -5066,6 +5066,7 @@ def transcription( api_version=api_version, azure_ad_token=azure_ad_token, max_retries=max_retries, + litellm_params=litellm_params_dict, ) elif ( custom_llm_provider == "openai" diff --git a/tests/litellm/llms/azure/test_azure_common_utils.py b/tests/litellm/llms/azure/test_azure_common_utils.py index e2bad3e7c5..27ec181a25 100644 --- a/tests/litellm/llms/azure/test_azure_common_utils.py +++ b/tests/litellm/llms/azure/test_azure_common_utils.py @@ -220,7 +220,7 @@ def test_select_azure_base_url_called(setup_mocks): CallTypes.atext_completion, CallTypes.aembedding, # CallTypes.arerank, - # CallTypes.atranscription, + CallTypes.atranscription, CallTypes.aspeech, CallTypes.aimage_generation, # BATCHES ENDPOINTS