diff --git a/litellm/router.py b/litellm/router.py index 3b1c1d102..46830d9ed 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -86,6 +86,9 @@ class Router: retry_policy: Optional[ RetryPolicy ] = None, # set custom retries for different exceptions + model_group_retry_policy: Optional[ + dict[str, RetryPolicy] + ] = {}, # set custom retry policies based on model group allowed_fails: Optional[ int ] = None, # Number of times a deployment can failbefore being added to cooldown @@ -308,6 +311,9 @@ class Router: ) # noqa self.routing_strategy_args = routing_strategy_args self.retry_policy: Optional[RetryPolicy] = retry_policy + self.model_group_retry_policy: Optional[dict[str, RetryPolicy]] = ( + model_group_retry_policy + ) def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict): if routing_strategy == "least-busy": @@ -1509,11 +1515,13 @@ class Router: ) await asyncio.sleep(_timeout) ## LOGGING - if self.retry_policy is not None or kwargs.get("retry_policy") is not None: + if ( + self.retry_policy is not None + or self.model_group_retry_policy is not None + ): # get num_retries from retry policy _retry_policy_retries = self.get_num_retries_from_retry_policy( - exception=original_exception, - dynamic_retry_policy=kwargs.get("retry_policy"), + exception=original_exception, model_group=kwargs.get("model") ) if _retry_policy_retries is not None: num_retries = _retry_policy_retries @@ -3269,7 +3277,7 @@ class Router: verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") def get_num_retries_from_retry_policy( - self, exception: Exception, dynamic_retry_policy: Optional[RetryPolicy] = None + self, exception: Exception, model_group: Optional[str] = None ): """ BadRequestErrorRetries: Optional[int] = None @@ -3280,8 +3288,9 @@ class Router: """ # if we can find the exception then in the retry policy -> return the number of retries retry_policy = self.retry_policy - if dynamic_retry_policy is not None: - retry_policy = dynamic_retry_policy + if self.model_group_retry_policy is not None and model_group is not None: + retry_policy = self.model_group_retry_policy.get(model_group, None) + if retry_policy is None: return None if (