diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py index cf62375f33..d061eeb219 100644 --- a/litellm/litellm_core_utils/get_litellm_params.py +++ b/litellm/litellm_core_utils/get_litellm_params.py @@ -60,6 +60,7 @@ def get_litellm_params( merge_reasoning_content_in_choices: Optional[bool] = None, **kwargs, ) -> dict: + litellm_params = { "acompletion": acompletion, "api_key": api_key, @@ -99,5 +100,11 @@ def get_litellm_params( "async_call": async_call, "ssl_verify": ssl_verify, "merge_reasoning_content_in_choices": merge_reasoning_content_in_choices, + "azure_ad_token": kwargs.get("azure_ad_token"), + "tenant_id": kwargs.get("tenant_id"), + "client_id": kwargs.get("client_id"), + "client_secret": kwargs.get("client_secret"), + "azure_username": kwargs.get("azure_username"), + "azure_password": kwargs.get("azure_password"), } return litellm_params diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 4575942d58..84e02bbf95 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -153,27 +153,16 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): timeout: Union[float, httpx.Timeout], client: Optional[Any], client_type: Literal["sync", "async"], + litellm_params: Optional[dict] = None, ): # init AzureOpenAI Client - azure_client_params: Dict[str, Any] = { - "api_version": api_version, - "azure_endpoint": api_base, - "azure_deployment": model, - "http_client": litellm.client_session, - "max_retries": max_retries, - "timeout": timeout, - } - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params=azure_client_params + azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + model_name=model, + api_version=api_version, + api_base=api_base, ) - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - azure_client_params["azure_ad_token"] = azure_ad_token - elif azure_ad_token_provider is not None: - azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider if client is None: if client_type == "sync": azure_client = AzureOpenAI(**azure_client_params) # type: ignore @@ -780,6 +769,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): client=None, aembedding=None, headers: Optional[dict] = None, + litellm_params: Optional[dict] = None, ) -> EmbeddingResponse: if headers: optional_params["extra_headers"] = headers @@ -795,29 +785,14 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): ) # init AzureOpenAI Client - azure_client_params = { - "api_version": api_version, - "azure_endpoint": api_base, - "azure_deployment": model, - "max_retries": max_retries, - "timeout": timeout, - } - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params=azure_client_params - ) - if aembedding: - azure_client_params["http_client"] = litellm.aclient_session - else: - azure_client_params["http_client"] = litellm.client_session - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - azure_client_params["azure_ad_token"] = azure_ad_token - elif azure_ad_token_provider is not None: - azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider + azure_client_params = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + model_name=model, + api_version=api_version, + api_base=api_base, + ) ## LOGGING logging_obj.pre_call( input=input, @@ -1282,6 +1257,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): azure_ad_token_provider: Optional[Callable] = None, aspeech: Optional[bool] = None, client=None, + litellm_params: Optional[dict] = None, ) -> HttpxBinaryResponseContent: max_retries = optional_params.pop("max_retries", 2) @@ -1300,6 +1276,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): max_retries=max_retries, timeout=timeout, client=client, + litellm_params=litellm_params, ) # type: ignore azure_client: AzureOpenAI = self._get_sync_azure_client( @@ -1313,6 +1290,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): timeout=timeout, client=client, client_type="sync", + litellm_params=litellm_params, ) # type: ignore response = azure_client.audio.speech.create( @@ -1337,6 +1315,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): max_retries: int, timeout: Union[float, httpx.Timeout], client=None, + litellm_params: Optional[dict] = None, ) -> HttpxBinaryResponseContent: azure_client: AsyncAzureOpenAI = self._get_sync_azure_client( @@ -1350,6 +1329,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): timeout=timeout, client=client, client_type="async", + litellm_params=litellm_params, ) # type: ignore azure_response = await azure_client.audio.speech.create( diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index d70554c2d2..272f5e86a9 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -58,6 +58,7 @@ def get_azure_openai_client( data[k] = v if "api_version" not in data: data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION + if _is_async is True: openai_client = AsyncAzureOpenAI(**data) else: diff --git a/litellm/main.py b/litellm/main.py index 846a908a8e..997c1ae75d 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1162,6 +1162,12 @@ def completion( # type: ignore # noqa: PLR0915 merge_reasoning_content_in_choices=kwargs.get( "merge_reasoning_content_in_choices", None ), + azure_ad_token=kwargs.get("azure_ad_token"), + tenant_id=kwargs.get("tenant_id"), + client_id=kwargs.get("client_id"), + client_secret=kwargs.get("client_secret"), + azure_username=kwargs.get("azure_username"), + azure_password=kwargs.get("azure_password"), ) logging.update_environment_variables( model=model, @@ -3411,6 +3417,7 @@ def embedding( # noqa: PLR0915 aembedding=aembedding, max_retries=max_retries, headers=headers or extra_headers, + litellm_params=litellm_params_dict, ) elif ( model in litellm.open_ai_embedding_models @@ -5002,6 +5009,7 @@ def transcription( custom_llm_provider=custom_llm_provider, drop_params=drop_params, ) + litellm_params_dict = get_litellm_params(**kwargs) litellm_logging_obj.update_environment_variables( model=model, @@ -5198,7 +5206,7 @@ def speech( if max_retries is None: max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES - + litellm_params_dict = get_litellm_params(**kwargs) logging_obj = kwargs.get("litellm_logging_obj", None) logging_obj.update_environment_variables( model=model, @@ -5315,6 +5323,7 @@ def speech( timeout=timeout, client=client, # pass AsyncOpenAI, OpenAI client aspeech=aspeech, + litellm_params=litellm_params_dict, ) elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta": diff --git a/tests/litellm/llms/azure/test_azure_common_utils.py b/tests/litellm/llms/azure/test_azure_common_utils.py index 6f1d86450f..e2bad3e7c5 100644 --- a/tests/litellm/llms/azure/test_azure_common_utils.py +++ b/tests/litellm/llms/azure/test_azure_common_utils.py @@ -219,8 +219,12 @@ def test_select_azure_base_url_called(setup_mocks): CallTypes.acompletion, CallTypes.atext_completion, CallTypes.aembedding, - CallTypes.arerank, - CallTypes.atranscription, + # CallTypes.arerank, + # CallTypes.atranscription, + CallTypes.aspeech, + CallTypes.aimage_generation, + # BATCHES ENDPOINTS + # ASSISTANT ENDPOINTS ], ) @pytest.mark.asyncio @@ -255,15 +259,20 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type): "aembedding": {"input": "Hello, how are you?"}, "arerank": {"input": "Hello, how are you?"}, "atranscription": {"file": "path/to/file"}, + "aspeech": {"input": "Hello, how are you?", "voice": "female"}, } # Get appropriate input for this call type input_kwarg = test_inputs.get(call_type.value, {}) + patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client" + if call_type == CallTypes.atranscription: + patch_target = ( + "litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client" + ) + # Mock the initialize_azure_sdk_client function - with patch( - "litellm.main.azure_chat_completions.initialize_azure_sdk_client" - ) as mock_init_azure: + with patch(patch_target) as mock_init_azure: # Also mock async_function_with_fallbacks to prevent actual API calls # Call the appropriate router method try: @@ -271,6 +280,7 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type): model="gpt-3.5-turbo", **input_kwarg, num_retries=0, + azure_ad_token="oidc/test-token", ) except Exception as e: print(e) @@ -282,6 +292,16 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type): calls = mock_init_azure.call_args_list azure_calls = [call for call in calls] + litellm_params = azure_calls[0].kwargs["litellm_params"] + print("litellm_params", litellm_params) + + assert ( + "azure_ad_token" in litellm_params + ), "azure_ad_token not found in parameters" + assert ( + litellm_params["azure_ad_token"] == "oidc/test-token" + ), "azure_ad_token is not correct" + # More detailed verification (optional) for call in azure_calls: assert "api_key" in call.kwargs, "api_key not found in parameters"