refactor(azure/audio_transcriptions.py): support client init with common logic

This commit is contained in:
Krrish Dholakia 2025-03-11 14:24:12 -07:00
parent 2c2404dac9
commit af71e14d79
3 changed files with 9 additions and 19 deletions

View file

@ -32,29 +32,18 @@ class AzureAudioTranscription(AzureChatCompletion):
client=None,
azure_ad_token: Optional[str] = None,
atranscription: bool = False,
litellm_params: Optional[dict] = None,
) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model,
api_version=api_version,
api_base=api_base,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if atranscription is True:
return self.async_audio_transcriptions( # type: ignore

View file

@ -5066,6 +5066,7 @@ def transcription(
api_version=api_version,
azure_ad_token=azure_ad_token,
max_retries=max_retries,
litellm_params=litellm_params_dict,
)
elif (
custom_llm_provider == "openai"

View file

@ -220,7 +220,7 @@ def test_select_azure_base_url_called(setup_mocks):
CallTypes.atext_completion,
CallTypes.aembedding,
# CallTypes.arerank,
# CallTypes.atranscription,
CallTypes.atranscription,
CallTypes.aspeech,
CallTypes.aimage_generation,
# BATCHES ENDPOINTS