test: fix tests

This commit is contained in:
Krrish Dholakia 2025-03-11 17:42:36 -07:00
parent cbc2e84044
commit 9af73f339a
5 changed files with 17 additions and 6 deletions

View file

@ -58,9 +58,9 @@ def get_litellm_params(
async_call: Optional[bool] = None, async_call: Optional[bool] = None,
ssl_verify: Optional[bool] = None, ssl_verify: Optional[bool] = None,
merge_reasoning_content_in_choices: Optional[bool] = None, merge_reasoning_content_in_choices: Optional[bool] = None,
max_retries: Optional[int] = None,
**kwargs, **kwargs,
) -> dict: ) -> dict:
litellm_params = { litellm_params = {
"acompletion": acompletion, "acompletion": acompletion,
"api_key": api_key, "api_key": api_key,
@ -106,5 +106,7 @@ def get_litellm_params(
"client_secret": kwargs.get("client_secret"), "client_secret": kwargs.get("client_secret"),
"azure_username": kwargs.get("azure_username"), "azure_username": kwargs.get("azure_username"),
"azure_password": kwargs.get("azure_password"), "azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
} }
return litellm_params return litellm_params

View file

@ -718,6 +718,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
): ):
response = None response = None
try: try:
if client is None: if client is None:
openai_aclient = AsyncAzureOpenAI(**azure_client_params) openai_aclient = AsyncAzureOpenAI(**azure_client_params)
else: else:

View file

@ -342,6 +342,9 @@ class BaseAzureLLM:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(azure_client_params)
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
return azure_client_params return azure_client_params

View file

@ -1168,6 +1168,8 @@ def completion( # type: ignore # noqa: PLR0915
client_secret=kwargs.get("client_secret"), client_secret=kwargs.get("client_secret"),
azure_username=kwargs.get("azure_username"), azure_username=kwargs.get("azure_username"),
azure_password=kwargs.get("azure_password"), azure_password=kwargs.get("azure_password"),
max_retries=max_retries,
timeout=timeout,
) )
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -3356,6 +3358,7 @@ def embedding( # noqa: PLR0915
} }
} }
) )
litellm_params_dict = get_litellm_params(**kwargs) litellm_params_dict = get_litellm_params(**kwargs)
logging: Logging = litellm_logging_obj # type: ignore logging: Logging = litellm_logging_obj # type: ignore

View file

@ -556,12 +556,11 @@ async def test_azure_instruct(
@pytest.mark.parametrize("max_retries", [0, 4]) @pytest.mark.parametrize("max_retries", [0, 4])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint") @patch("litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_azure_embedding_max_retries_0( async def test_azure_embedding_max_retries_0(
mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode mock_select_azure_base_url_or_endpoint, max_retries, sync_mode
): ):
from litellm import aembedding, embedding from litellm import aembedding, embedding
@ -569,7 +568,6 @@ async def test_azure_embedding_max_retries_0(
"model": "azure/azure-embedding-model", "model": "azure/azure-embedding-model",
"input": "Hello world", "input": "Hello world",
"max_retries": max_retries, "max_retries": max_retries,
"stream": stream,
} }
try: try:
@ -581,6 +579,10 @@ async def test_azure_embedding_max_retries_0(
print(e) print(e)
mock_select_azure_base_url_or_endpoint.assert_called_once() mock_select_azure_base_url_or_endpoint.assert_called_once()
print(
"mock_select_azure_base_url_or_endpoint.call_args.kwargs",
mock_select_azure_base_url_or_endpoint.call_args.kwargs,
)
assert ( assert (
mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][ mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
"max_retries" "max_retries"