mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test: fix tests
This commit is contained in:
parent
2f262ed9b4
commit
934c06c207
5 changed files with 17 additions and 6 deletions
|
@ -58,9 +58,9 @@ def get_litellm_params(
|
|||
async_call: Optional[bool] = None,
|
||||
ssl_verify: Optional[bool] = None,
|
||||
merge_reasoning_content_in_choices: Optional[bool] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
|
||||
litellm_params = {
|
||||
"acompletion": acompletion,
|
||||
"api_key": api_key,
|
||||
|
@ -106,5 +106,7 @@ def get_litellm_params(
|
|||
"client_secret": kwargs.get("client_secret"),
|
||||
"azure_username": kwargs.get("azure_username"),
|
||||
"azure_password": kwargs.get("azure_password"),
|
||||
"max_retries": max_retries,
|
||||
"timeout": kwargs.get("timeout"),
|
||||
}
|
||||
return litellm_params
|
||||
|
|
|
@ -718,6 +718,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
|||
):
|
||||
response = None
|
||||
try:
|
||||
|
||||
if client is None:
|
||||
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
|
|
|
@ -342,6 +342,9 @@ class BaseAzureLLM:
|
|||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
|
||||
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
|
||||
azure_client_params = select_azure_base_url_or_endpoint(azure_client_params)
|
||||
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
|
||||
return azure_client_params
|
||||
|
|
|
@ -1168,6 +1168,8 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
client_secret=kwargs.get("client_secret"),
|
||||
azure_username=kwargs.get("azure_username"),
|
||||
azure_password=kwargs.get("azure_password"),
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
)
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
@ -3356,6 +3358,7 @@ def embedding( # noqa: PLR0915
|
|||
}
|
||||
}
|
||||
)
|
||||
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
logging: Logging = litellm_logging_obj # type: ignore
|
||||
|
|
|
@ -556,12 +556,11 @@ async def test_azure_instruct(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("max_retries", [0, 4])
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint")
|
||||
@patch("litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint")
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_embedding_max_retries_0(
|
||||
mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode
|
||||
mock_select_azure_base_url_or_endpoint, max_retries, sync_mode
|
||||
):
|
||||
from litellm import aembedding, embedding
|
||||
|
||||
|
@ -569,7 +568,6 @@ async def test_azure_embedding_max_retries_0(
|
|||
"model": "azure/azure-embedding-model",
|
||||
"input": "Hello world",
|
||||
"max_retries": max_retries,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
try:
|
||||
|
@ -581,6 +579,10 @@ async def test_azure_embedding_max_retries_0(
|
|||
print(e)
|
||||
|
||||
mock_select_azure_base_url_or_endpoint.assert_called_once()
|
||||
print(
|
||||
"mock_select_azure_base_url_or_endpoint.call_args.kwargs",
|
||||
mock_select_azure_base_url_or_endpoint.call_args.kwargs,
|
||||
)
|
||||
assert (
|
||||
mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
|
||||
"max_retries"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue