diff --git a/litellm/router.py b/litellm/router.py
index 23000d9575..e330bdd9e7 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1512,7 +1512,7 @@ class Router:
             Retry Logic

             """
-            _, _healthy_deployments = self._common_checks_available_deployment(
+            _healthy_deployments = await self._async_get_healthy_deployments(
                 model=kwargs.get("model"),
             )

@@ -1520,7 +1520,6 @@ class Router:
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
                 context_window_fallbacks=context_window_fallbacks,
             )

@@ -1530,7 +1529,6 @@ class Router:
                 remaining_retries=num_retries,
                 num_retries=num_retries,
                 _healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
             )

             # sleeps for the length of the timeout
@@ -1575,7 +1573,6 @@ class Router:
                     remaining_retries=remaining_retries,
                     num_retries=num_retries,
                     healthy_deployments=_healthy_deployments,
-                    fallbacks=fallbacks,
                 )
                 await asyncio.sleep(_timeout)
                 try:
@@ -1588,7 +1585,6 @@ class Router:
         self,
         error: Exception,
         healthy_deployments: Optional[List] = None,
-        fallbacks: Optional[List] = None,
         context_window_fallbacks: Optional[List] = None,
     ):
         """
@@ -1604,15 +1600,17 @@ class Router:
             _num_healthy_deployments = len(healthy_deployments)

         ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
-
         if (
             isinstance(error, litellm.ContextWindowExceededError)
             and context_window_fallbacks is None
         ):
             raise error

-        if isinstance(error, openai.RateLimitError):
-            if fallbacks is None and _num_healthy_deployments <= 0:
+        # Error we should only retry if there are other deployments
+        if isinstance(error, openai.RateLimitError) or isinstance(
+            error, openai.AuthenticationError
+        ):
+            if _num_healthy_deployments <= 0:
                 raise error

         return True
@@ -1711,7 +1709,6 @@ class Router:
         remaining_retries: int,
         num_retries: int,
         healthy_deployments: Optional[List] = None,
-        fallbacks: Optional[List] = None,
     ) -> Union[int, float]:
         """
         Calculate back-off, then retry
@@ -1727,9 +1724,6 @@ class Router:
         ):
             return 0

-        if fallbacks is not None and isinstance(fallbacks, list) and len(fallbacks) > 0:
-            return 0
-
         if hasattr(e, "response") and hasattr(e.response, "headers"):
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,
@@ -1766,7 +1760,7 @@ class Router:
         except Exception as e:
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            _, _healthy_deployments = self._common_checks_available_deployment(
+            _healthy_deployments = self._get_healthy_deployments(
                 model=kwargs.get("model"),
             )

@@ -1774,7 +1768,6 @@ class Router:
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
                 context_window_fallbacks=context_window_fallbacks,
             )

@@ -1784,7 +1777,6 @@ class Router:
                 remaining_retries=num_retries,
                 num_retries=num_retries,
                 _healthy_deployments=_healthy_deployments,
-                fallbacks=fallbacks,
             )

             ## LOGGING
@@ -1813,7 +1805,6 @@ class Router:
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
                         healthy_deployments=_healthy_deployments,
-                        fallbacks=fallbacks,
                     )
                     time.sleep(_timeout)
             raise original_exception
@@ -2016,6 +2007,35 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models

+    def _get_healthy_deployments(self, model: str):
+        _, _all_deployments = self._common_checks_available_deployment(
+            model=model,
+        )
+
+        unhealthy_deployments = self._get_cooldown_deployments()
+        healthy_deployments = []
+        for deployment in _all_deployments:
+            if deployment["model_info"]["id"] in unhealthy_deployments:
+                continue
+            else:
+                healthy_deployments.append(deployment)
+
+        return healthy_deployments
+
+    async def _async_get_healthy_deployments(self, model: str):
+        _, _all_deployments = self._common_checks_available_deployment(
+            model=model,
+        )
+
+        unhealthy_deployments = await self._async_get_cooldown_deployments()
+        healthy_deployments = []
+        for deployment in _all_deployments:
+            if deployment["model_info"]["id"] in unhealthy_deployments:
+                continue
+            else:
+                healthy_deployments.append(deployment)
+        return healthy_deployments
+
     def routing_strategy_pre_call_checks(self, deployment: dict):
         """
         Mimics 'async_routing_strategy_pre_call_checks'