diff --git a/litellm/main.py b/litellm/main.py
index 59d98580cf..d19463f532 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -620,6 +620,7 @@ def completion(
         "model_list",
         "num_retries",
         "context_window_fallback_dict",
+        "retry_policy",
         "roles",
         "final_prompt_value",
         "bos_token",
@@ -2687,6 +2688,7 @@ def embedding(
         "model_list",
         "num_retries",
         "context_window_fallback_dict",
+        "retry_policy",
         "roles",
         "final_prompt_value",
         "bos_token",
@@ -3556,6 +3558,7 @@ def image_generation(
         "model_list",
         "num_retries",
         "context_window_fallback_dict",
+        "retry_policy",
         "roles",
         "final_prompt_value",
         "bos_token",
diff --git a/litellm/router.py b/litellm/router.py
index 258f50457a..3b1c1d1022 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1509,10 +1509,11 @@ class Router:
                 )
                 await asyncio.sleep(_timeout)
             ## LOGGING
-            if self.retry_policy is not None:
+            if self.retry_policy is not None or kwargs.get("retry_policy") is not None:
                 # get num_retries from retry policy
                 _retry_policy_retries = self.get_num_retries_from_retry_policy(
-                    exception=original_exception
+                    exception=original_exception,
+                    dynamic_retry_policy=kwargs.get("retry_policy"),
                 )
                 if _retry_policy_retries is not None:
                     num_retries = _retry_policy_retries
@@ -3267,7 +3268,9 @@ class Router:
         except Exception as e:
             verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
 
-    def get_num_retries_from_retry_policy(self, exception: Exception):
+    def get_num_retries_from_retry_policy(
+        self, exception: Exception, dynamic_retry_policy: Optional[RetryPolicy] = None
+    ):
         """
         BadRequestErrorRetries: Optional[int] = None
         AuthenticationErrorRetries: Optional[int] = None
@@ -3276,33 +3279,36 @@ class Router:
         ContentPolicyViolationErrorRetries: Optional[int] = None
         """
         # if we can find the exception then in the retry policy -> return the number of retries
-        if self.retry_policy is None:
+        retry_policy = self.retry_policy
+        if dynamic_retry_policy is not None:
+            retry_policy = dynamic_retry_policy
+        if retry_policy is None:
             return None
         if (
             isinstance(exception, litellm.BadRequestError)
-            and self.retry_policy.BadRequestErrorRetries is not None
+            and retry_policy.BadRequestErrorRetries is not None
         ):
-            return self.retry_policy.BadRequestErrorRetries
+            return retry_policy.BadRequestErrorRetries
         if (
             isinstance(exception, litellm.AuthenticationError)
-            and self.retry_policy.AuthenticationErrorRetries is not None
+            and retry_policy.AuthenticationErrorRetries is not None
         ):
-            return self.retry_policy.AuthenticationErrorRetries
+            return retry_policy.AuthenticationErrorRetries
         if (
             isinstance(exception, litellm.Timeout)
-            and self.retry_policy.TimeoutErrorRetries is not None
+            and retry_policy.TimeoutErrorRetries is not None
         ):
-            return self.retry_policy.TimeoutErrorRetries
+            return retry_policy.TimeoutErrorRetries
         if (
             isinstance(exception, litellm.RateLimitError)
-            and self.retry_policy.RateLimitErrorRetries is not None
+            and retry_policy.RateLimitErrorRetries is not None
         ):
-            return self.retry_policy.RateLimitErrorRetries
+            return retry_policy.RateLimitErrorRetries
         if (
             isinstance(exception, litellm.ContentPolicyViolationError)
-            and self.retry_policy.ContentPolicyViolationErrorRetries is not None
+            and retry_policy.ContentPolicyViolationErrorRetries is not None
         ):
-            return self.retry_policy.ContentPolicyViolationErrorRetries
+            return retry_policy.ContentPolicyViolationErrorRetries
 
     def flush_cache(self):
         litellm.cache = None
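
For context (not part of the diff): a minimal usage sketch of the request-level `retry_policy` this change wires through. It assumes `RetryPolicy` can be imported from `litellm.router` (the diff references the type in router.py but does not show its import), uses field names taken from the docstring of `get_num_retries_from_retry_policy`, and the model names and router config are placeholders.

```python
import asyncio

# Assumed import path; router.py above references RetryPolicy but the diff
# does not show where it is defined.
from litellm.router import Router, RetryPolicy


async def main():
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",  # placeholder deployment
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
    )

    # Per-request policy: field names come from the docstring shown in the diff.
    request_policy = RetryPolicy(
        RateLimitErrorRetries=3,
        TimeoutErrorRetries=1,
        AuthenticationErrorRetries=0,
    )

    # retry_policy is passed as a plain kwarg; per the router.py hunk it is read
    # via kwargs.get("retry_policy") and forwarded as dynamic_retry_policy,
    # taking precedence over the router-level self.retry_policy for this call.
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        retry_policy=request_policy,
    )
    print(response)


asyncio.run(main())
```

The main.py hunks only register `retry_policy` in the list of router-handled parameters for `completion`, `embedding`, and `image_generation`, so the kwarg is treated as a litellm/router parameter rather than being forwarded to the underlying provider.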