From 934c06c2074004ff59153f7e4808a241a82b468e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 11 Mar 2025 17:42:36 -0700 Subject: [PATCH] test: fix tests --- litellm/litellm_core_utils/get_litellm_params.py | 4 +++- litellm/llms/azure/azure.py | 1 + litellm/llms/azure/common_utils.py | 5 ++++- litellm/main.py | 3 +++ tests/llm_translation/test_azure_openai.py | 10 ++++++---- 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py index d061eeb219..d1166e157b 100644 --- a/litellm/litellm_core_utils/get_litellm_params.py +++ b/litellm/litellm_core_utils/get_litellm_params.py @@ -58,9 +58,9 @@ def get_litellm_params( async_call: Optional[bool] = None, ssl_verify: Optional[bool] = None, merge_reasoning_content_in_choices: Optional[bool] = None, + max_retries: Optional[int] = None, **kwargs, ) -> dict: - litellm_params = { "acompletion": acompletion, "api_key": api_key, @@ -106,5 +106,7 @@ def get_litellm_params( "client_secret": kwargs.get("client_secret"), "azure_username": kwargs.get("azure_username"), "azure_password": kwargs.get("azure_password"), + "max_retries": max_retries, + "timeout": kwargs.get("timeout"), } return litellm_params diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index d0875412f6..0f155af427 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -718,6 +718,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): ): response = None try: + if client is None: openai_aclient = AsyncAzureOpenAI(**azure_client_params) else: diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index e7795f78cb..d409839c4d 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -342,6 +342,9 @@ class BaseAzureLLM: azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client - azure_client_params = select_azure_base_url_or_endpoint(azure_client_params) + + azure_client_params = select_azure_base_url_or_endpoint( + azure_client_params=azure_client_params + ) return azure_client_params diff --git a/litellm/main.py b/litellm/main.py index 0d80ac4943..02f69192a2 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1168,6 +1168,8 @@ def completion( # type: ignore # noqa: PLR0915 client_secret=kwargs.get("client_secret"), azure_username=kwargs.get("azure_username"), azure_password=kwargs.get("azure_password"), + max_retries=max_retries, + timeout=timeout, ) logging.update_environment_variables( model=model, @@ -3356,6 +3358,7 @@ def embedding( # noqa: PLR0915 } } ) + litellm_params_dict = get_litellm_params(**kwargs) logging: Logging = litellm_logging_obj # type: ignore diff --git a/tests/llm_translation/test_azure_openai.py b/tests/llm_translation/test_azure_openai.py index d4715b8906..ef5fd69b76 100644 --- a/tests/llm_translation/test_azure_openai.py +++ b/tests/llm_translation/test_azure_openai.py @@ -556,12 +556,11 @@ async def test_azure_instruct( @pytest.mark.parametrize("max_retries", [0, 4]) -@pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False]) -@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint") +@patch("litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint") @pytest.mark.asyncio async def test_azure_embedding_max_retries_0( - mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode + mock_select_azure_base_url_or_endpoint, max_retries, sync_mode ): from litellm import aembedding, embedding @@ -569,7 +568,6 @@ async def test_azure_embedding_max_retries_0( "model": "azure/azure-embedding-model", "input": "Hello world", "max_retries": max_retries, - "stream": stream, } try: @@ -581,6 +579,10 @@ async def test_azure_embedding_max_retries_0( print(e) mock_select_azure_base_url_or_endpoint.assert_called_once() + print( + "mock_select_azure_base_url_or_endpoint.call_args.kwargs", + mock_select_azure_base_url_or_endpoint.call_args.kwargs, + ) assert ( mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][ "max_retries"