diff --git a/litellm/router.py b/litellm/router.py
index 8ea1a124a..f173d52fb 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1418,6 +1418,13 @@ class Router:
             traceback.print_exc()
             raise original_exception
 
+    async def _async_router_should_retry(
+        self, e: Exception, remaining_retries: int, num_retries: int
+    ):
+        """
+        Calculate back-off, then retry
+        """
+
     async def async_function_with_retries(self, *args, **kwargs):
         verbose_router_logger.debug(
             f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
@@ -1450,40 +1457,47 @@ class Router:
                 raise original_exception
             ### RETRY
             #### check if it should retry + back-off if required
-            if "No models available" in str(
-                e
-            ) or RouterErrors.no_deployments_available.value in str(e):
-                timeout = litellm._calculate_retry_after(
-                    remaining_retries=num_retries,
-                    max_retries=num_retries,
-                    min_timeout=self.retry_after,
-                )
-                await asyncio.sleep(timeout)
-            elif RouterErrors.user_defined_ratelimit_error.value in str(e):
-                raise e  # don't wait to retry if deployment hits user-defined rate-limit
+            # if "No models available" in str(
+            #     e
+            # ) or RouterErrors.no_deployments_available.value in str(e):
+            #     timeout = litellm._calculate_retry_after(
+            #         remaining_retries=num_retries,
+            #         max_retries=num_retries,
+            #         min_timeout=self.retry_after,
+            #     )
+            #     await asyncio.sleep(timeout)
+            # elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+            #     raise e  # don't wait to retry if deployment hits user-defined rate-limit
 
-            elif hasattr(original_exception, "status_code") and litellm._should_retry(
-                status_code=original_exception.status_code
-            ):
-                if hasattr(original_exception, "response") and hasattr(
-                    original_exception.response, "headers"
-                ):
-                    timeout = litellm._calculate_retry_after(
-                        remaining_retries=num_retries,
-                        max_retries=num_retries,
-                        response_headers=original_exception.response.headers,
-                        min_timeout=self.retry_after,
-                    )
-                else:
-                    timeout = litellm._calculate_retry_after(
-                        remaining_retries=num_retries,
-                        max_retries=num_retries,
-                        min_timeout=self.retry_after,
-                    )
-                await asyncio.sleep(timeout)
-            else:
-                raise original_exception
+            # elif hasattr(original_exception, "status_code") and litellm._should_retry(
+            #     status_code=original_exception.status_code
+            # ):
+            #     if hasattr(original_exception, "response") and hasattr(
+            #         original_exception.response, "headers"
+            #     ):
+            #         timeout = litellm._calculate_retry_after(
+            #             remaining_retries=num_retries,
+            #             max_retries=num_retries,
+            #             response_headers=original_exception.response.headers,
+            #             min_timeout=self.retry_after,
+            #         )
+            #     else:
+            #         timeout = litellm._calculate_retry_after(
+            #             remaining_retries=num_retries,
+            #             max_retries=num_retries,
+            #             min_timeout=self.retry_after,
+            #         )
+            #     await asyncio.sleep(timeout)
+            # else:
+            #     raise original_exception
+            ### RETRY
+            _timeout = self._router_should_retry(
+                e=original_exception,
+                remaining_retries=num_retries,
+                num_retries=num_retries,
+            )
+            await asyncio.sleep(_timeout)
             ## LOGGING
             if num_retries > 0:
                 kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
 
@@ -1505,34 +1519,37 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    if "No models available" in str(e):
-                        timeout = litellm._calculate_retry_after(
-                            remaining_retries=remaining_retries,
-                            max_retries=num_retries,
-                            min_timeout=self.retry_after,
-                        )
-                        await asyncio.sleep(timeout)
-                    elif (
-                        hasattr(e, "status_code")
-                        and hasattr(e, "response")
-                        and litellm._should_retry(status_code=e.status_code)
-                    ):
-                        if hasattr(e.response, "headers"):
-                            timeout = litellm._calculate_retry_after(
-                                remaining_retries=remaining_retries,
-                                max_retries=num_retries,
-                                response_headers=e.response.headers,
-                                min_timeout=self.retry_after,
-                            )
-                        else:
-                            timeout = litellm._calculate_retry_after(
-                                remaining_retries=remaining_retries,
-                                max_retries=num_retries,
-                                min_timeout=self.retry_after,
-                            )
-                        await asyncio.sleep(timeout)
-                    else:
-                        raise e
+                    # if "No models available" in str(e):
+                    #     timeout = litellm._calculate_retry_after(
+                    #         remaining_retries=remaining_retries,
+                    #         max_retries=num_retries,
+                    #         min_timeout=self.retry_after,
+                    #     )
+                    #     await asyncio.sleep(timeout)
+                    # elif (
+                    #     hasattr(e, "status_code")
+                    #     and hasattr(e, "response")
+                    #     and litellm._should_retry(status_code=e.status_code)
+                    # ):
+                    #     if hasattr(e.response, "headers"):
+                    #         timeout = litellm._calculate_retry_after(
+                    #             remaining_retries=remaining_retries,
+                    #             max_retries=num_retries,
+                    #             response_headers=e.response.headers,
+                    #             min_timeout=self.retry_after,
+                    #         )
+                    #     else:
+                    #         timeout = litellm._calculate_retry_after(
+                    #             remaining_retries=remaining_retries,
+                    #             max_retries=num_retries,
+                    #             min_timeout=self.retry_after,
+                    #         )
+                    _timeout = self._router_should_retry(
+                        e=e,
+                        remaining_retries=remaining_retries,
+                        num_retries=num_retries,
+                    )
+                    await asyncio.sleep(_timeout)
             raise original_exception
 
     def function_with_fallbacks(self, *args, **kwargs):
@@ -1625,7 +1642,7 @@ class Router:
 
     def _router_should_retry(
         self, e: Exception, remaining_retries: int, num_retries: int
-    ):
+    ) -> Union[int, float]:
         """
         Calculate back-off, then retry
         """
@@ -1636,14 +1653,13 @@ class Router:
                 response_headers=e.response.headers,
                 min_timeout=self.retry_after,
             )
-            time.sleep(timeout)
         else:
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,
                 max_retries=num_retries,
                 min_timeout=self.retry_after,
             )
-            time.sleep(timeout)
+        return timeout
 
     def function_with_retries(self, *args, **kwargs):
         """
@@ -1677,11 +1693,12 @@ class Router:
             if num_retries > 0:
                 kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
             ### RETRY
-            self._router_should_retry(
+            _timeout = self._router_should_retry(
                 e=original_exception,
                 remaining_retries=num_retries,
                 num_retries=num_retries,
             )
+            time.sleep(_timeout)
             for current_attempt in range(num_retries):
                 verbose_router_logger.debug(
                     f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
@@ -1695,11 +1712,12 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    self._router_should_retry(
+                    _timeout = self._router_should_retry(
                         e=e,
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
                     )
+                    time.sleep(_timeout)
             raise original_exception
 
     ### HELPER FUNCTIONS
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 7520ac75f..8c6b9fa01 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -104,6 +104,42 @@ def test_router_timeout_init(timeout, ssl_verify):
     )
 
 
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_router_retries(sync_mode):
+    """
+    - make sure retries work as expected
+    """
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad-key"},
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+            },
+        },
+    ]
+
+    router = Router(model_list=model_list, num_retries=2)
+
+    if sync_mode:
+        router.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+    else:
+        await router.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+
+
 @pytest.mark.parametrize(
     "mistral_api_base",
    [
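
Note: a minimal sketch (not part of the patch) of the pattern this refactor moves to. `_router_should_retry` now returns the computed back-off instead of sleeping itself, so the sync and async paths share one helper and each picks the appropriate sleep. The function names `sync_backoff`/`async_backoff` and the `router`/`err` parameters are illustrative only, not code from the diff.

    import asyncio
    import time

    def sync_backoff(router, err, remaining_retries, num_retries):
        # The helper only computes the wait; the blocking sleep stays in the sync caller.
        _timeout = router._router_should_retry(
            e=err, remaining_retries=remaining_retries, num_retries=num_retries
        )
        time.sleep(_timeout)

    async def async_backoff(router, err, remaining_retries, num_retries):
        # Same helper; awaiting asyncio.sleep keeps the event loop unblocked,
        # which the old sleep-inside-the-helper design could not do.
        _timeout = router._router_should_retry(
            e=err, remaining_retries=remaining_retries, num_retries=num_retries
        )
        await asyncio.sleep(_timeout)

Returning the timeout rather than sleeping inside `_router_should_retry` is what lets `async_function_with_retries` call `await asyncio.sleep(_timeout)` instead of blocking the event loop with `time.sleep`.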