From 821844c1a3a2662b04a39c27b5a49155d146b41f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 24 Apr 2024 22:02:48 -0700 Subject: [PATCH] fix(router.py): fix max retries on set_client --- litellm/router.py | 15 ++++++++++++--- litellm/tests/test_router.py | 6 ++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 371d8e8ebd..15a82cb090 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -70,7 +70,7 @@ class Router: ] = None, # if you want to cache across model groups client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds ## RELIABILITY ## - num_retries: int = 0, + num_retries: Optional[int] = None, timeout: Optional[float] = None, default_litellm_params={}, # default params for Router.chat.completion.create default_max_parallel_requests: Optional[int] = None, @@ -229,7 +229,12 @@ class Router: self.failed_calls = ( InMemoryCache() ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown - self.num_retries = num_retries or litellm.num_retries or 0 + self.num_retries = num_retries # type: ignore + if self.num_retries is None: + if litellm.num_retries is not None: + self.num_retries = litellm.num_retries + else: + self.num_retries = openai.DEFAULT_MAX_RETRIES self.timeout = timeout or litellm.request_timeout self.retry_after = retry_after @@ -1986,7 +1991,7 @@ class Router: stream_timeout = litellm.get_secret(stream_timeout_env_name) litellm_params["stream_timeout"] = stream_timeout - max_retries = litellm_params.pop("max_retries", 2) + max_retries = litellm_params.pop("max_retries", self.num_retries) if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): max_retries_env_name = max_retries.replace("os.environ/", "") max_retries = litellm.get_secret(max_retries_env_name) @@ -2883,6 +2888,10 @@ class Router: model=model, healthy_deployments=healthy_deployments, messages=messages ) + if len(healthy_deployments) == 0: + raise ValueError( + f"No deployments available for selected model, passed model={model}" + ) if ( self.routing_strategy == "usage-based-routing-v2" and self.lowesttpm_logger_v2 is not None diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 7beb1d67c7..ac47bbc21c 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -18,6 +18,12 @@ from dotenv import load_dotenv load_dotenv() +def test_router_num_retries_init(): + router = Router(num_retries=0) + + assert router.num_retries == 0 + + def test_exception_raising(): # this tests if the router raises an exception when invalid params are set # in this test both deployments have bad keys - Keep this test. It validates if the router raises the most recent exception