From e95be13f101e8f0006fd23aa2b9f28d35102b63d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 23:02:50 -0700 Subject: [PATCH] fix(router.py): fix router retry policy logic --- litellm/router.py | 45 +++++----------------------- litellm/tests/test_router_retries.py | 3 +- 2 files changed, 9 insertions(+), 39 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 49699ffed..fbb245a3d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1472,49 +1472,14 @@ class Router: ): raise original_exception ### RETRY - #### check if it should retry + back-off if required - # if "No models available" in str( - # e - # ) or RouterErrors.no_deployments_available.value in str(e): - # timeout = litellm._calculate_retry_after( - # remaining_retries=num_retries, - # max_retries=num_retries, - # min_timeout=self.retry_after, - # ) - # await asyncio.sleep(timeout) - # elif RouterErrors.user_defined_ratelimit_error.value in str(e): - # raise e # don't wait to retry if deployment hits user-defined rate-limit - # elif hasattr(original_exception, "status_code") and litellm._should_retry( - # status_code=original_exception.status_code - # ): - # if hasattr(original_exception, "response") and hasattr( - # original_exception.response, "headers" - # ): - # timeout = litellm._calculate_retry_after( - # remaining_retries=num_retries, - # max_retries=num_retries, - # response_headers=original_exception.response.headers, - # min_timeout=self.retry_after, - # ) - # else: - # timeout = litellm._calculate_retry_after( - # remaining_retries=num_retries, - # max_retries=num_retries, - # min_timeout=self.retry_after, - # ) - # await asyncio.sleep(timeout) - # else: - # raise original_exception - - ### RETRY _timeout = self._router_should_retry( e=original_exception, remaining_retries=num_retries, num_retries=num_retries, ) await asyncio.sleep(_timeout) - ## LOGGING + if ( self.retry_policy is not None or self.model_group_retry_policy is not None @@ -1525,7 +1490,7 @@ class Router: ) if _retry_policy_retries is not None: num_retries = _retry_policy_retries - + ## LOGGING if num_retries > 0: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) @@ -3292,7 +3257,11 @@ class Router: """ # if we can find the exception then in the retry policy -> return the number of retries retry_policy = self.retry_policy - if self.model_group_retry_policy is not None and model_group is not None: + if ( + self.model_group_retry_policy is not None + and model_group is not None + and model_group in self.model_group_retry_policy + ): retry_policy = self.model_group_retry_policy.get(model_group, None) if retry_policy is None: diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py index 0ca566cae..6ae3c2c45 100644 --- a/litellm/tests/test_router_retries.py +++ b/litellm/tests/test_router_retries.py @@ -123,7 +123,8 @@ async def test_router_retries_errors(sync_mode, error_type): @pytest.mark.asyncio @pytest.mark.parametrize( - "error_type", ["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"] + "error_type", + ["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], # ) async def test_router_retry_policy(error_type): from litellm.router import RetryPolicy