Merge pull request #3376 from BerriAI/litellm_routing_logic

fix(router.py): unify retry timeout logic across sync + async function_with_retries
Krish Dholakia 2024-04-30 19:58:45 -07:00 committed by GitHub
commit 9f55a99e98
7 changed files with 265 additions and 82 deletions
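
In outline, the fix replaces two diverging inline back-off blocks with a single helper, `_router_should_retry`, that computes the wait and returns it instead of sleeping itself, so the sync path can `time.sleep(...)` on the result while the async path can `await asyncio.sleep(...)`. A minimal sketch of that shape (names like `compute_retry_timeout` are illustrative, not litellm's actual API):

    import asyncio
    import time


    def compute_retry_timeout(remaining_retries: int, num_retries: int, min_timeout: float = 0.5) -> float:
        """Stand-in for the shared back-off calculation; returns seconds to wait."""
        exponent = num_retries - remaining_retries
        return max(min_timeout, min_timeout * (2**exponent))


    def call_with_retries(fn, num_retries: int):
        """Sync caller: blocks with time.sleep() on the returned timeout."""
        for attempt in range(num_retries + 1):
            try:
                return fn()
            except Exception:
                if attempt == num_retries:
                    raise  # out of retries; surface the last error
                time.sleep(compute_retry_timeout(num_retries - attempt, num_retries))


    async def acall_with_retries(fn, num_retries: int):
        """Async caller: same timeout computation, but awaits asyncio.sleep()."""
        for attempt in range(num_retries + 1):
            try:
                return await fn()
            except Exception:
                if attempt == num_retries:
                    raise
                await asyncio.sleep(compute_retry_timeout(num_retries - attempt, num_retries))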


@@ -1450,40 +1450,47 @@ class Router:
                raise original_exception
            ### RETRY
            #### check if it should retry + back-off if required
            if "No models available" in str(
                e
            ) or RouterErrors.no_deployments_available.value in str(e):
                timeout = litellm._calculate_retry_after(
                    remaining_retries=num_retries,
                    max_retries=num_retries,
                    min_timeout=self.retry_after,
                )
                await asyncio.sleep(timeout)
            elif RouterErrors.user_defined_ratelimit_error.value in str(e):
                raise e  # don't wait to retry if deployment hits user-defined rate-limit
            # if "No models available" in str(
            #     e
            # ) or RouterErrors.no_deployments_available.value in str(e):
            #     timeout = litellm._calculate_retry_after(
            #         remaining_retries=num_retries,
            #         max_retries=num_retries,
            #         min_timeout=self.retry_after,
            #     )
            #     await asyncio.sleep(timeout)
            # elif RouterErrors.user_defined_ratelimit_error.value in str(e):
            #     raise e  # don't wait to retry if deployment hits user-defined rate-limit
            elif hasattr(original_exception, "status_code") and litellm._should_retry(
                status_code=original_exception.status_code
            ):
                if hasattr(original_exception, "response") and hasattr(
                    original_exception.response, "headers"
                ):
                    timeout = litellm._calculate_retry_after(
                        remaining_retries=num_retries,
                        max_retries=num_retries,
                        response_headers=original_exception.response.headers,
                        min_timeout=self.retry_after,
                    )
                else:
                    timeout = litellm._calculate_retry_after(
                        remaining_retries=num_retries,
                        max_retries=num_retries,
                        min_timeout=self.retry_after,
                    )
                await asyncio.sleep(timeout)
            else:
                raise original_exception
            # elif hasattr(original_exception, "status_code") and litellm._should_retry(
            #     status_code=original_exception.status_code
            # ):
            #     if hasattr(original_exception, "response") and hasattr(
            #         original_exception.response, "headers"
            #     ):
            #         timeout = litellm._calculate_retry_after(
            #             remaining_retries=num_retries,
            #             max_retries=num_retries,
            #             response_headers=original_exception.response.headers,
            #             min_timeout=self.retry_after,
            #         )
            #     else:
            #         timeout = litellm._calculate_retry_after(
            #             remaining_retries=num_retries,
            #             max_retries=num_retries,
            #             min_timeout=self.retry_after,
            #         )
            #     await asyncio.sleep(timeout)
            # else:
            #     raise original_exception
            ### RETRY
            _timeout = self._router_should_retry(
                e=original_exception,
                remaining_retries=num_retries,
                num_retries=num_retries,
            )
            await asyncio.sleep(_timeout)
            ## LOGGING
            if num_retries > 0:
                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
@@ -1505,34 +1512,12 @@ class Router:
                    ## LOGGING
                    kwargs = self.log_retry(kwargs=kwargs, e=e)
                    remaining_retries = num_retries - current_attempt
                    if "No models available" in str(e):
                        timeout = litellm._calculate_retry_after(
                            remaining_retries=remaining_retries,
                            max_retries=num_retries,
                            min_timeout=self.retry_after,
                        )
                        await asyncio.sleep(timeout)
                    elif (
                        hasattr(e, "status_code")
                        and hasattr(e, "response")
                        and litellm._should_retry(status_code=e.status_code)
                    ):
                        if hasattr(e.response, "headers"):
                            timeout = litellm._calculate_retry_after(
                                remaining_retries=remaining_retries,
                                max_retries=num_retries,
                                response_headers=e.response.headers,
                                min_timeout=self.retry_after,
                            )
                        else:
                            timeout = litellm._calculate_retry_after(
                                remaining_retries=remaining_retries,
                                max_retries=num_retries,
                                min_timeout=self.retry_after,
                            )
                        await asyncio.sleep(timeout)
                    else:
                        raise e
                    _timeout = self._router_should_retry(
                        e=original_exception,
                        remaining_retries=remaining_retries,
                        num_retries=num_retries,
                    )
                    await asyncio.sleep(_timeout)
            raise original_exception

    def function_with_fallbacks(self, *args, **kwargs):
@@ -1625,7 +1610,7 @@ class Router:

    def _router_should_retry(
        self, e: Exception, remaining_retries: int, num_retries: int
    ):
    ) -> Union[int, float]:
        """
        Calculate back-off, then retry
        """
@@ -1636,14 +1621,13 @@ class Router:
                response_headers=e.response.headers,
                min_timeout=self.retry_after,
            )
            time.sleep(timeout)
        else:
            timeout = litellm._calculate_retry_after(
                remaining_retries=remaining_retries,
                max_retries=num_retries,
                min_timeout=self.retry_after,
            )
            time.sleep(timeout)
        return timeout

    def function_with_retries(self, *args, **kwargs):
        """
@@ -1658,6 +1642,7 @@ class Router:
        context_window_fallbacks = kwargs.pop(
            "context_window_fallbacks", self.context_window_fallbacks
        )

        try:
            # if the function call is successful, no exception will be raised and we'll break out of the loop
            response = original_function(*args, **kwargs)
@@ -1677,11 +1662,12 @@ class Router:
            if num_retries > 0:
                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
            ### RETRY
            self._router_should_retry(
            _timeout = self._router_should_retry(
                e=original_exception,
                remaining_retries=num_retries,
                num_retries=num_retries,
            )
            time.sleep(_timeout)
            for current_attempt in range(num_retries):
                verbose_router_logger.debug(
                    f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
@@ -1695,11 +1681,12 @@ class Router:
                    ## LOGGING
                    kwargs = self.log_retry(kwargs=kwargs, e=e)
                    remaining_retries = num_retries - current_attempt
                    self._router_should_retry(
                    _timeout = self._router_should_retry(
                        e=e,
                        remaining_retries=remaining_retries,
                        num_retries=num_retries,
                    )
                    time.sleep(_timeout)
            raise original_exception

    ### HELPER FUNCTIONS
@@ -1733,10 +1720,11 @@ class Router:
            ) # i.e. azure
            metadata = kwargs.get("litellm_params", {}).get("metadata", None)
            _model_info = kwargs.get("litellm_params", {}).get("model_info", {})

            if isinstance(_model_info, dict):
                deployment_id = _model_info.get("id", None)
                self._set_cooldown_deployments(
                    deployment_id
                    exception_status=exception_status, deployment=deployment_id
                )  # setting deployment_id in cooldown deployments
            if custom_llm_provider:
                model_name = f"{custom_llm_provider}/{model_name}"
@@ -1796,9 +1784,15 @@ class Router:
                key=rpm_key, value=request_count, local_only=True
            ) # don't change existing ttl

    def _set_cooldown_deployments(self, deployment: Optional[str] = None):
    def _set_cooldown_deployments(
        self, exception_status: Union[str, int], deployment: Optional[str] = None
    ):
        """
        Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute

        or

        the exception is not one that should be immediately retried (e.g. 401)
        """
        if deployment is None:
            return
@@ -1815,7 +1809,20 @@ class Router:
            f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
        )
        cooldown_time = self.cooldown_time or 1
        if updated_fails > self.allowed_fails:
        if isinstance(exception_status, str):
            try:
                exception_status = int(exception_status)
            except Exception as e:
                verbose_router_logger.debug(
                    "Unable to cast exception status to int {}. Defaulting to status=500.".format(
                        exception_status
                    )
                )
                exception_status = 500

        _should_retry = litellm._should_retry(status_code=exception_status)

        if updated_fails > self.allowed_fails or _should_retry == False:
            # get the current cooldown list for that minute
            cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
            cached_value = self.cache.get_cache(key=cooldown_key)
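
The net effect of the `_set_cooldown_deployments` change: the exception status is normalized to an int (defaulting to 500 when it cannot be cast), and a deployment is cooled down either when it exceeds `allowed_fails` for the current minute or when the status is one that should not be retried at all (e.g. 401, where retrying the same deployment cannot succeed). A condensed sketch of that decision; the retryable status set below is an assumption, not necessarily `litellm._should_retry`'s exact behavior:

    from typing import Union

    # Assumed set of transient statuses worth retrying; illustrative only.
    RETRYABLE_STATUSES = frozenset({408, 409, 429, 500, 502, 503, 504})


    def should_cooldown(
        updated_fails: int,
        allowed_fails: int,
        exception_status: Union[str, int],
    ) -> bool:
        # Normalize the status code; default to 500 when it cannot be parsed.
        try:
            status = int(exception_status)
        except (TypeError, ValueError):
            status = 500
        # Cool down on repeated failures, or immediately on a non-retryable
        # status (e.g. 401 bad key) where retrying the deployment cannot help.
        return updated_fails > allowed_fails or status not in RETRYABLE_STATUSES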