Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

fix(router.py): ensure no unsupported args are passed to completion()

commit 53b879bc6c
parent f19f0dad89
2 changed files with 7 additions and 7 deletions
@@ -1528,6 +1528,9 @@ class Router:
             max_retries_env_name = max_retries.replace("os.environ/", "")
             max_retries = litellm.get_secret(max_retries_env_name)
             max_retries = int(max_retries)
+            litellm_params[
+                "max_retries"
+            ] = max_retries  # do this for testing purposes

             if "azure" in model_name:
                 if api_base is None:
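For context, litellm router configs can set litellm_params values to strings of the form "os.environ/<NAME>", which the router resolves from the environment; the hunk above casts the resolved max_retries to int and stores it back on litellm_params for the tests. A minimal standalone sketch of that resolution, assuming an illustrative env value and reading os.environ directly where router.py calls litellm.get_secret():

    import os

    # Illustrative value; the tests below read AZURE_MAX_RETRIES the same way.
    os.environ["AZURE_MAX_RETRIES"] = "4"

    # How the value would appear in litellm_params before resolution.
    max_retries = "os.environ/AZURE_MAX_RETRIES"

    if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
        max_retries_env_name = max_retries.replace("os.environ/", "")
        max_retries = os.environ[max_retries_env_name]  # stand-in for litellm.get_secret()
        max_retries = int(max_retries)  # env values are strings, so cast

    assert max_retries == 4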
@@ -783,9 +783,6 @@ def test_reading_keys_os_environ():
         assert float(model["litellm_params"]["timeout"]) == float(
             os.environ["AZURE_TIMEOUT"]
         ), f"{model['litellm_params']['timeout']} vs {os.environ['AZURE_TIMEOUT']}"
-        assert float(model["litellm_params"]["stream_timeout"]) == float(
-            os.environ["AZURE_STREAM_TIMEOUT"]
-        ), f"{model['litellm_params']['stream_timeout']} vs {os.environ['AZURE_STREAM_TIMEOUT']}"
         assert int(model["litellm_params"]["max_retries"]) == int(
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
@@ -794,7 +791,7 @@ def test_reading_keys_os_environ():
         async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client")  # type: ignore
         assert async_client.api_key == os.environ["AZURE_API_KEY"]
         assert async_client.base_url == os.environ["AZURE_API_BASE"]
-        assert async_client.max_retries == (
+        assert async_client.max_retries == int(
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
         assert async_client.timeout == (
@@ -807,7 +804,7 @@ def test_reading_keys_os_environ():
         stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client")  # type: ignore
         assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
         assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
-        assert stream_async_client.max_retries == (
+        assert stream_async_client.max_retries == int(
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{stream_async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
         assert stream_async_client.timeout == (
@@ -819,7 +816,7 @@ def test_reading_keys_os_environ():
         client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client")  # type: ignore
         assert client.api_key == os.environ["AZURE_API_KEY"]
         assert client.base_url == os.environ["AZURE_API_BASE"]
-        assert client.max_retries == (
+        assert client.max_retries == int(
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
         assert client.timeout == (
@@ -831,7 +828,7 @@ def test_reading_keys_os_environ():
         stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client")  # type: ignore
         assert stream_client.api_key == os.environ["AZURE_API_KEY"]
         assert stream_client.base_url == os.environ["AZURE_API_BASE"]
-        assert stream_client.max_retries == (
+        assert stream_client.max_retries == int(
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{stream_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
         assert stream_client.timeout == (
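A note on the "(" → "int(" change repeated across the four client asserts above: os.environ values are always strings, while the openai client objects hold max_retries as an int, so casting the env value keeps the comparison type-correct. A small illustration with a hypothetical value:

    import os

    os.environ["AZURE_MAX_RETRIES"] = "4"  # env vars are always strings
    client_max_retries = 4                 # what the client object stores

    # int vs. str never compares equal; casting the env value fixes that.
    assert client_max_retries != os.environ["AZURE_MAX_RETRIES"]
    assert client_max_retries == int(os.environ["AZURE_MAX_RETRIES"])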