refactor(azure.py): working azure client init on audio speech endpoint

This commit is contained in:
Krrish Dholakia 2025-03-11 14:19:45 -07:00
parent 858d9005a2
commit 4f4507ccc0
5 changed files with 63 additions and 46 deletions

View file

@ -60,6 +60,7 @@ def get_litellm_params(
merge_reasoning_content_in_choices: Optional[bool] = None, merge_reasoning_content_in_choices: Optional[bool] = None,
**kwargs, **kwargs,
) -> dict: ) -> dict:
litellm_params = { litellm_params = {
"acompletion": acompletion, "acompletion": acompletion,
"api_key": api_key, "api_key": api_key,
@ -99,5 +100,11 @@ def get_litellm_params(
"async_call": async_call, "async_call": async_call,
"ssl_verify": ssl_verify, "ssl_verify": ssl_verify,
"merge_reasoning_content_in_choices": merge_reasoning_content_in_choices, "merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
"azure_ad_token": kwargs.get("azure_ad_token"),
"tenant_id": kwargs.get("tenant_id"),
"client_id": kwargs.get("client_id"),
"client_secret": kwargs.get("client_secret"),
"azure_username": kwargs.get("azure_username"),
"azure_password": kwargs.get("azure_password"),
} }
return litellm_params return litellm_params

View file

@ -153,27 +153,16 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
client: Optional[Any], client: Optional[Any],
client_type: Literal["sync", "async"], client_type: Literal["sync", "async"],
litellm_params: Optional[dict] = None,
): ):
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params: Dict[str, Any] = { azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
"api_version": api_version, litellm_params=litellm_params or {},
"azure_endpoint": api_base, api_key=api_key,
"azure_deployment": model, model_name=model,
"http_client": litellm.client_session, api_version=api_version,
"max_retries": max_retries, api_base=api_base,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
) )
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None: if client is None:
if client_type == "sync": if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore azure_client = AzureOpenAI(**azure_client_params) # type: ignore
@ -780,6 +769,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None, client=None,
aembedding=None, aembedding=None,
headers: Optional[dict] = None, headers: Optional[dict] = None,
litellm_params: Optional[dict] = None,
) -> EmbeddingResponse: ) -> EmbeddingResponse:
if headers: if headers:
optional_params["extra_headers"] = headers optional_params["extra_headers"] = headers
@ -795,29 +785,14 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
) )
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if aembedding:
azure_client_params["http_client"] = litellm.aclient_session
else:
azure_client_params["http_client"] = litellm.client_session
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model,
api_version=api_version,
api_base=api_base,
)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
@ -1282,6 +1257,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
azure_ad_token_provider: Optional[Callable] = None, azure_ad_token_provider: Optional[Callable] = None,
aspeech: Optional[bool] = None, aspeech: Optional[bool] = None,
client=None, client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
max_retries = optional_params.pop("max_retries", 2) max_retries = optional_params.pop("max_retries", 2)
@ -1300,6 +1276,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
max_retries=max_retries, max_retries=max_retries,
timeout=timeout, timeout=timeout,
client=client, client=client,
litellm_params=litellm_params,
) # type: ignore ) # type: ignore
azure_client: AzureOpenAI = self._get_sync_azure_client( azure_client: AzureOpenAI = self._get_sync_azure_client(
@ -1313,6 +1290,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
timeout=timeout, timeout=timeout,
client=client, client=client,
client_type="sync", client_type="sync",
litellm_params=litellm_params,
) # type: ignore ) # type: ignore
response = azure_client.audio.speech.create( response = azure_client.audio.speech.create(
@ -1337,6 +1315,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
max_retries: int, max_retries: int,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
client=None, client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
azure_client: AsyncAzureOpenAI = self._get_sync_azure_client( azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(
@ -1350,6 +1329,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
timeout=timeout, timeout=timeout,
client=client, client=client,
client_type="async", client_type="async",
litellm_params=litellm_params,
) # type: ignore ) # type: ignore
azure_response = await azure_client.audio.speech.create( azure_response = await azure_client.audio.speech.create(

View file

@ -58,6 +58,7 @@ def get_azure_openai_client(
data[k] = v data[k] = v
if "api_version" not in data: if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True: if _is_async is True:
openai_client = AsyncAzureOpenAI(**data) openai_client = AsyncAzureOpenAI(**data)
else: else:

View file

@ -1162,6 +1162,12 @@ def completion( # type: ignore # noqa: PLR0915
merge_reasoning_content_in_choices=kwargs.get( merge_reasoning_content_in_choices=kwargs.get(
"merge_reasoning_content_in_choices", None "merge_reasoning_content_in_choices", None
), ),
azure_ad_token=kwargs.get("azure_ad_token"),
tenant_id=kwargs.get("tenant_id"),
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
azure_username=kwargs.get("azure_username"),
azure_password=kwargs.get("azure_password"),
) )
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -3411,6 +3417,7 @@ def embedding( # noqa: PLR0915
aembedding=aembedding, aembedding=aembedding,
max_retries=max_retries, max_retries=max_retries,
headers=headers or extra_headers, headers=headers or extra_headers,
litellm_params=litellm_params_dict,
) )
elif ( elif (
model in litellm.open_ai_embedding_models model in litellm.open_ai_embedding_models
@ -5002,6 +5009,7 @@ def transcription(
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
drop_params=drop_params, drop_params=drop_params,
) )
litellm_params_dict = get_litellm_params(**kwargs)
litellm_logging_obj.update_environment_variables( litellm_logging_obj.update_environment_variables(
model=model, model=model,
@ -5198,7 +5206,7 @@ def speech(
if max_retries is None: if max_retries is None:
max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
litellm_params_dict = get_litellm_params(**kwargs)
logging_obj = kwargs.get("litellm_logging_obj", None) logging_obj = kwargs.get("litellm_logging_obj", None)
logging_obj.update_environment_variables( logging_obj.update_environment_variables(
model=model, model=model,
@ -5315,6 +5323,7 @@ def speech(
timeout=timeout, timeout=timeout,
client=client, # pass AsyncOpenAI, OpenAI client client=client, # pass AsyncOpenAI, OpenAI client
aspeech=aspeech, aspeech=aspeech,
litellm_params=litellm_params_dict,
) )
elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta": elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":

View file

@ -219,8 +219,12 @@ def test_select_azure_base_url_called(setup_mocks):
CallTypes.acompletion, CallTypes.acompletion,
CallTypes.atext_completion, CallTypes.atext_completion,
CallTypes.aembedding, CallTypes.aembedding,
CallTypes.arerank, # CallTypes.arerank,
CallTypes.atranscription, # CallTypes.atranscription,
CallTypes.aspeech,
CallTypes.aimage_generation,
# BATCHES ENDPOINTS
# ASSISTANT ENDPOINTS
], ],
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -255,15 +259,20 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
"aembedding": {"input": "Hello, how are you?"}, "aembedding": {"input": "Hello, how are you?"},
"arerank": {"input": "Hello, how are you?"}, "arerank": {"input": "Hello, how are you?"},
"atranscription": {"file": "path/to/file"}, "atranscription": {"file": "path/to/file"},
"aspeech": {"input": "Hello, how are you?", "voice": "female"},
} }
# Get appropriate input for this call type # Get appropriate input for this call type
input_kwarg = test_inputs.get(call_type.value, {}) input_kwarg = test_inputs.get(call_type.value, {})
patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
if call_type == CallTypes.atranscription:
patch_target = (
"litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client"
)
# Mock the initialize_azure_sdk_client function # Mock the initialize_azure_sdk_client function
with patch( with patch(patch_target) as mock_init_azure:
"litellm.main.azure_chat_completions.initialize_azure_sdk_client"
) as mock_init_azure:
# Also mock async_function_with_fallbacks to prevent actual API calls # Also mock async_function_with_fallbacks to prevent actual API calls
# Call the appropriate router method # Call the appropriate router method
try: try:
@ -271,6 +280,7 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
**input_kwarg, **input_kwarg,
num_retries=0, num_retries=0,
azure_ad_token="oidc/test-token",
) )
except Exception as e: except Exception as e:
print(e) print(e)
@ -282,6 +292,16 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
calls = mock_init_azure.call_args_list calls = mock_init_azure.call_args_list
azure_calls = [call for call in calls] azure_calls = [call for call in calls]
litellm_params = azure_calls[0].kwargs["litellm_params"]
print("litellm_params", litellm_params)
assert (
"azure_ad_token" in litellm_params
), "azure_ad_token not found in parameters"
assert (
litellm_params["azure_ad_token"] == "oidc/test-token"
), "azure_ad_token is not correct"
# More detailed verification (optional) # More detailed verification (optional)
for call in azure_calls: for call in azure_calls:
assert "api_key" in call.kwargs, "api_key not found in parameters" assert "api_key" in call.kwargs, "api_key not found in parameters"