mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
test: fix tests
This commit is contained in:
parent
cbc2e84044
commit
9af73f339a
5 changed files with 17 additions and 6 deletions
|
@ -58,9 +58,9 @@ def get_litellm_params(
|
||||||
async_call: Optional[bool] = None,
|
async_call: Optional[bool] = None,
|
||||||
ssl_verify: Optional[bool] = None,
|
ssl_verify: Optional[bool] = None,
|
||||||
merge_reasoning_content_in_choices: Optional[bool] = None,
|
merge_reasoning_content_in_choices: Optional[bool] = None,
|
||||||
|
max_retries: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
||||||
litellm_params = {
|
litellm_params = {
|
||||||
"acompletion": acompletion,
|
"acompletion": acompletion,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
@ -106,5 +106,7 @@ def get_litellm_params(
|
||||||
"client_secret": kwargs.get("client_secret"),
|
"client_secret": kwargs.get("client_secret"),
|
||||||
"azure_username": kwargs.get("azure_username"),
|
"azure_username": kwargs.get("azure_username"),
|
||||||
"azure_password": kwargs.get("azure_password"),
|
"azure_password": kwargs.get("azure_password"),
|
||||||
|
"max_retries": max_retries,
|
||||||
|
"timeout": kwargs.get("timeout"),
|
||||||
}
|
}
|
||||||
return litellm_params
|
return litellm_params
|
||||||
|
|
|
@ -718,6 +718,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if client is None:
|
if client is None:
|
||||||
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
|
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -342,6 +342,9 @@ class BaseAzureLLM:
|
||||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
|
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
|
||||||
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
|
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
|
||||||
azure_client_params = select_azure_base_url_or_endpoint(azure_client_params)
|
|
||||||
|
azure_client_params = select_azure_base_url_or_endpoint(
|
||||||
|
azure_client_params=azure_client_params
|
||||||
|
)
|
||||||
|
|
||||||
return azure_client_params
|
return azure_client_params
|
||||||
|
|
|
@ -1168,6 +1168,8 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
client_secret=kwargs.get("client_secret"),
|
client_secret=kwargs.get("client_secret"),
|
||||||
azure_username=kwargs.get("azure_username"),
|
azure_username=kwargs.get("azure_username"),
|
||||||
azure_password=kwargs.get("azure_password"),
|
azure_password=kwargs.get("azure_password"),
|
||||||
|
max_retries=max_retries,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
logging.update_environment_variables(
|
logging.update_environment_variables(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -3356,6 +3358,7 @@ def embedding( # noqa: PLR0915
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
litellm_params_dict = get_litellm_params(**kwargs)
|
litellm_params_dict = get_litellm_params(**kwargs)
|
||||||
|
|
||||||
logging: Logging = litellm_logging_obj # type: ignore
|
logging: Logging = litellm_logging_obj # type: ignore
|
||||||
|
|
|
@ -556,12 +556,11 @@ async def test_azure_instruct(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("max_retries", [0, 4])
|
@pytest.mark.parametrize("max_retries", [0, 4])
|
||||||
@pytest.mark.parametrize("stream", [True, False])
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint")
|
@patch("litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_azure_embedding_max_retries_0(
|
async def test_azure_embedding_max_retries_0(
|
||||||
mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode
|
mock_select_azure_base_url_or_endpoint, max_retries, sync_mode
|
||||||
):
|
):
|
||||||
from litellm import aembedding, embedding
|
from litellm import aembedding, embedding
|
||||||
|
|
||||||
|
@ -569,7 +568,6 @@ async def test_azure_embedding_max_retries_0(
|
||||||
"model": "azure/azure-embedding-model",
|
"model": "azure/azure-embedding-model",
|
||||||
"input": "Hello world",
|
"input": "Hello world",
|
||||||
"max_retries": max_retries,
|
"max_retries": max_retries,
|
||||||
"stream": stream,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -581,6 +579,10 @@ async def test_azure_embedding_max_retries_0(
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
mock_select_azure_base_url_or_endpoint.assert_called_once()
|
mock_select_azure_base_url_or_endpoint.assert_called_once()
|
||||||
|
print(
|
||||||
|
"mock_select_azure_base_url_or_endpoint.call_args.kwargs",
|
||||||
|
mock_select_azure_base_url_or_endpoint.call_args.kwargs,
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
|
mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
|
||||||
"max_retries"
|
"max_retries"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue