diff --git a/litellm/router.py b/litellm/router.py
index 48a970319..d4d7dd2c1 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1524,10 +1524,12 @@ class Router:
                 context_window_fallbacks=context_window_fallbacks,
             )
 
-            _timeout = self._router_should_retry(
+            _timeout = self._time_to_sleep_before_retry(
                 e=original_exception,
                 remaining_retries=num_retries,
                 num_retries=num_retries,
+                healthy_deployments=_healthy_deployments,
+                fallbacks=fallbacks,
             )
 
             ### RETRY
@@ -1564,7 +1566,7 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    _timeout = self._router_should_retry(
+                    _timeout = self._time_to_sleep_before_retry(
                         e=original_exception,
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
@@ -1697,12 +1699,31 @@ class Router:
             raise e
         raise original_exception
 
-    def _router_should_retry(
-        self, e: Exception, remaining_retries: int, num_retries: int
+    def _time_to_sleep_before_retry(
+        self,
+        e: Exception,
+        remaining_retries: int,
+        num_retries: int,
+        healthy_deployments: Optional[List] = None,
+        fallbacks: Optional[List] = None,
     ) -> Union[int, float]:
         """
         Calculate back-off, then retry
+
+        Returns 0 (retry immediately, without sleeping) when either:
+            1. there are healthy deployments in the same model group
+            2. there are fallbacks for the completion call
         """
+        if (
+            healthy_deployments is not None
+            and isinstance(healthy_deployments, list)
+            and len(healthy_deployments) > 0
+        ):
+            return 0
+
+        if fallbacks is not None and isinstance(fallbacks, list) and len(fallbacks) > 0:
+            return 0
+
         if hasattr(e, "response") and hasattr(e.response, "headers"):
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,
@@ -1751,7 +1772,7 @@ class Router:
             if num_retries > 0:
                 kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
             ### RETRY
-            _timeout = self._router_should_retry(
+            _timeout = self._time_to_sleep_before_retry(
                 e=original_exception,
                 remaining_retries=num_retries,
                 num_retries=num_retries,
@@ -1770,7 +1791,7 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    _timeout = self._router_should_retry(
+                    _timeout = self._time_to_sleep_before_retry(
                         e=e,
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,