test: fix tests

This commit is contained in:
Krrish Dholakia 2025-03-11 17:42:36 -07:00
parent cbc2e84044
commit 9af73f339a
5 changed files with 17 additions and 6 deletions

View file

@@ -58,9 +58,9 @@ def get_litellm_params(
async_call: Optional[bool] = None,
ssl_verify: Optional[bool] = None,
merge_reasoning_content_in_choices: Optional[bool] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> dict:
litellm_params = {
"acompletion": acompletion,
"api_key": api_key,
@@ -106,5 +106,7 @@ def get_litellm_params(
"client_secret": kwargs.get("client_secret"),
"azure_username": kwargs.get("azure_username"),
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
}
return litellm_params

View file

@@ -718,6 +718,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
):
response = None
try:
if client is None:
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
else:

View file

@@ -342,6 +342,9 @@ class BaseAzureLLM:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(azure_client_params)
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
return azure_client_params

View file

@@ -1168,6 +1168,8 @@ def completion( # type: ignore # noqa: PLR0915
client_secret=kwargs.get("client_secret"),
azure_username=kwargs.get("azure_username"),
azure_password=kwargs.get("azure_password"),
max_retries=max_retries,
timeout=timeout,
)
logging.update_environment_variables(
model=model,
@@ -3356,6 +3358,7 @@ def embedding( # noqa: PLR0915
}
}
)
litellm_params_dict = get_litellm_params(**kwargs)
logging: Logging = litellm_logging_obj # type: ignore

View file

@@ -556,12 +556,11 @@ async def test_azure_instruct(
@pytest.mark.parametrize("max_retries", [0, 4])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("sync_mode", [True, False])
@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint")
@patch("litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint")
@pytest.mark.asyncio
async def test_azure_embedding_max_retries_0(
mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode
mock_select_azure_base_url_or_endpoint, max_retries, sync_mode
):
from litellm import aembedding, embedding
@@ -569,7 +568,6 @@ async def test_azure_embedding_max_retries_0(
"model": "azure/azure-embedding-model",
"input": "Hello world",
"max_retries": max_retries,
"stream": stream,
}
try:
@@ -581,6 +579,10 @@ async def test_azure_embedding_max_retries_0(
print(e)
mock_select_azure_base_url_or_endpoint.assert_called_once()
print(
"mock_select_azure_base_url_or_endpoint.call_args.kwargs",
mock_select_azure_base_url_or_endpoint.call_args.kwargs,
)
assert (
mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
"max_retries"