fix - _time_to_sleep_before_retry

Ishaan Jaff 2024-05-11 19:08:10 -07:00
parent 4d0f525d56
commit ffdf68d7e8


@@ -1512,7 +1512,7 @@ class Router:
             Retry Logic
             """
-            _, _healthy_deployments = self._common_checks_available_deployment(
+            _healthy_deployments = await self._async_get_healthy_deployments(
                 model=kwargs.get("model"),
             )
@@ -1520,7 +1520,6 @@ class Router:
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
                 context_window_fallbacks=context_window_fallbacks,
             )
@@ -1530,7 +1529,6 @@ class Router:
                 remaining_retries=num_retries,
                 num_retries=num_retries,
                 _healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
             )
             # sleeps for the length of the timeout
@@ -1575,7 +1573,6 @@ class Router:
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
                         healthy_deployments=_healthy_deployments,
-                        fallbacks=fallbacks,
                     )
                     await asyncio.sleep(_timeout)
                 try:
@@ -1588,7 +1585,6 @@ class Router:
         self,
         error: Exception,
         healthy_deployments: Optional[List] = None,
-        fallbacks: Optional[List] = None,
         context_window_fallbacks: Optional[List] = None,
     ):
         """
@@ -1604,15 +1600,17 @@ class Router:
             _num_healthy_deployments = len(healthy_deployments)
 
         ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
         if (
             isinstance(error, litellm.ContextWindowExceededError)
             and context_window_fallbacks is None
         ):
             raise error
 
-        if isinstance(error, openai.RateLimitError):
-            if fallbacks is None and _num_healthy_deployments <= 0:
+        # Error we should only retry if there are other deployments
+        if isinstance(error, openai.RateLimitError) or isinstance(
+            error, openai.AuthenticationError
+        ):
+            if _num_healthy_deployments <= 0:
                 raise error
 
         return True
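
Net effect of the hunk above: `should_retry_this_error` no longer consults `fallbacks`; rate-limit errors, and now authentication errors as well, are retried only when at least one other healthy deployment is available, otherwise the error is surfaced immediately. A minimal standalone sketch of that gate (the `should_retry` helper and the mocked response below are illustrative, not litellm code):

    import httpx
    import openai

    def should_retry(error: Exception, num_healthy_deployments: int) -> bool:
        # Rate-limit / auth errors are only worth retrying when another
        # deployment could still serve the request.
        if isinstance(error, (openai.RateLimitError, openai.AuthenticationError)):
            return num_healthy_deployments > 0
        return True

    _resp = httpx.Response(429, request=httpx.Request("POST", "https://example.invalid"))
    err = openai.RateLimitError("rate limited", response=_resp, body=None)
    assert should_retry(err, num_healthy_deployments=0) is False
    assert should_retry(err, num_healthy_deployments=2) is True
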
@@ -1711,7 +1709,6 @@ class Router:
         remaining_retries: int,
         num_retries: int,
         healthy_deployments: Optional[List] = None,
-        fallbacks: Optional[List] = None,
     ) -> Union[int, float]:
         """
         Calculate back-off, then retry
@@ -1727,9 +1724,6 @@ class Router:
         ):
             return 0
 
-        if fallbacks is not None and isinstance(fallbacks, list) and len(fallbacks) > 0:
-            return 0
-
         if hasattr(e, "response") and hasattr(e.response, "headers"):
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,
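
With the `fallbacks` shortcut removed, `_time_to_sleep_before_retry` now derives the back-off from the error and the remaining healthy deployments alone. For orientation, a rough standalone sketch of the general retry-after / exponential back-off pattern; the exact formula below is an assumption, not litellm's implementation:

    import random
    from typing import Optional

    def sleep_before_retry(
        retry_after_header: Optional[float],
        remaining_retries: int,
        num_retries: int,
        max_backoff: float = 8.0,
    ) -> float:
        # Honor a server-provided Retry-After value when present...
        if retry_after_header is not None and retry_after_header > 0:
            return retry_after_header
        # ...otherwise back off exponentially with a little jitter.
        attempt = max(0, num_retries - remaining_retries)
        return min(max_backoff, 0.5 * (2 ** attempt)) + random.uniform(0, 0.1)

    print(sleep_before_retry(None, remaining_retries=2, num_retries=3))
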
@@ -1766,7 +1760,7 @@ class Router:
         except Exception as e:
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            _, _healthy_deployments = self._common_checks_available_deployment(
+            _healthy_deployments = self._get_healthy_deployments(
                 model=kwargs.get("model"),
             )
@@ -1774,7 +1768,6 @@ class Router:
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
                 context_window_fallbacks=context_window_fallbacks,
             )
@@ -1784,7 +1777,6 @@ class Router:
                 remaining_retries=num_retries,
                 num_retries=num_retries,
                 _healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
             )
             ## LOGGING
@@ -1813,7 +1805,6 @@ class Router:
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
                         healthy_deployments=_healthy_deployments,
-                        fallbacks=fallbacks,
                     )
                     time.sleep(_timeout)
             raise original_exception
@@ -2016,6 +2007,35 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
+    def _get_healthy_deployments(self, model: str):
+        _, _all_deployments = self._common_checks_available_deployment(
+            model=model,
+        )
+
+        unhealthy_deployments = self._get_cooldown_deployments()
+        healthy_deployments = []
+        for deployment in _all_deployments:
+            if deployment["model_info"]["id"] in unhealthy_deployments:
+                continue
+            else:
+                healthy_deployments.append(deployment)
+
+        return healthy_deployments
+
+    async def _async_get_healthy_deployments(self, model: str):
+        _, _all_deployments = self._common_checks_available_deployment(
+            model=model,
+        )
+
+        unhealthy_deployments = await self._async_get_cooldown_deployments()
+        healthy_deployments = []
+        for deployment in _all_deployments:
+            if deployment["model_info"]["id"] in unhealthy_deployments:
+                continue
+            else:
+                healthy_deployments.append(deployment)
+
+        return healthy_deployments
+
     def routing_strategy_pre_call_checks(self, deployment: dict):
         """
         Mimics 'async_routing_strategy_pre_call_checks'
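
The two helpers added at the bottom of the diff share one filtering step: fetch every deployment registered for the model via `_common_checks_available_deployment`, then drop any whose `model_info.id` is currently on the cooldown list. A standalone sketch of that filter with made-up deployment dicts (litellm's real deployment entries carry more fields):

    from typing import Dict, List

    def filter_healthy(all_deployments: List[Dict], cooldown_ids: List[str]) -> List[Dict]:
        # Keep only deployments whose id is not currently cooling down.
        return [d for d in all_deployments if d["model_info"]["id"] not in cooldown_ids]

    deployments = [
        {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deploy-1"}},
        {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deploy-2"}},
    ]
    print(filter_healthy(deployments, cooldown_ids=["deploy-1"]))
    # -> only the "deploy-2" entry remains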