fix(router.py): fix router retry policy logic

This commit is contained in:
Krrish Dholakia 2024-05-04 23:02:50 -07:00
parent 0529d7eaa3
commit e95be13f10
2 changed files with 9 additions and 39 deletions

View file

@ -1472,49 +1472,14 @@ class Router:
):
raise original_exception
### RETRY
#### check if it should retry + back-off if required
# if "No models available" in str(
# e
# ) or RouterErrors.no_deployments_available.value in str(e):
# timeout = litellm._calculate_retry_after(
# remaining_retries=num_retries,
# max_retries=num_retries,
# min_timeout=self.retry_after,
# )
# await asyncio.sleep(timeout)
# elif RouterErrors.user_defined_ratelimit_error.value in str(e):
# raise e # don't wait to retry if deployment hits user-defined rate-limit
# elif hasattr(original_exception, "status_code") and litellm._should_retry(
# status_code=original_exception.status_code
# ):
# if hasattr(original_exception, "response") and hasattr(
# original_exception.response, "headers"
# ):
# timeout = litellm._calculate_retry_after(
# remaining_retries=num_retries,
# max_retries=num_retries,
# response_headers=original_exception.response.headers,
# min_timeout=self.retry_after,
# )
# else:
# timeout = litellm._calculate_retry_after(
# remaining_retries=num_retries,
# max_retries=num_retries,
# min_timeout=self.retry_after,
# )
# await asyncio.sleep(timeout)
# else:
# raise original_exception
### RETRY
_timeout = self._router_should_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
)
await asyncio.sleep(_timeout)
## LOGGING
if (
self.retry_policy is not None
or self.model_group_retry_policy is not None
@ -1525,7 +1490,7 @@ class Router:
)
if _retry_policy_retries is not None:
num_retries = _retry_policy_retries
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
@ -3292,7 +3257,11 @@ class Router:
"""
# if we can find the exception then in the retry policy -> return the number of retries
retry_policy = self.retry_policy
if self.model_group_retry_policy is not None and model_group is not None:
if (
self.model_group_retry_policy is not None
and model_group is not None
and model_group in self.model_group_retry_policy
):
retry_policy = self.model_group_retry_policy.get(model_group, None)
if retry_policy is None:

View file

@ -123,7 +123,8 @@ async def test_router_retries_errors(sync_mode, error_type):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"error_type", ["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"]
"error_type",
["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], #
)
async def test_router_retry_policy(error_type):
from litellm.router import RetryPolicy