From 5d17c814a3a072fc059b84db99bca8cc9f3821b3 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 4 May 2024 17:04:51 -0700
Subject: [PATCH] router - use retry policy

---
Review notes (this section is ignored by `git am`):
- Line structure and context-line whitespace were reconstructed from a
  whitespace-collapsed copy of this patch; if it does not apply cleanly,
  retry with `git apply --ignore-whitespace`.
- get_num_retries_from_retry_policy now checks ContentPolicyViolationError
  before BadRequestError (the former derives from the latter) and returns
  None explicitly when nothing matches.

 litellm/router.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/litellm/router.py b/litellm/router.py
index d64deecec..55342b40b 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -42,6 +42,7 @@ from litellm.types.router import (
     RouterErrors,
     updateDeployment,
     updateLiteLLMParams,
+    RetryPolicy,
 )
 from litellm.integrations.custom_logger import CustomLogger
 
@@ -82,6 +83,9 @@ class Router:
         model_group_alias: Optional[dict] = {},
         enable_pre_call_checks: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
+        retry_policy: Optional[
+            RetryPolicy
+        ] = None,  # set custom retries for different exceptions
         allowed_fails: Optional[
             int
         ] = None,  # Number of times a deployment can failbefore being added to cooldown
@@ -303,6 +307,7 @@ class Router:
             f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
         )  # noqa
         self.routing_strategy_args = routing_strategy_args
+        self.retry_policy: Optional[RetryPolicy] = retry_policy
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
@@ -1504,6 +1509,14 @@ class Router:
             )
             await asyncio.sleep(_timeout)
             ## LOGGING
+            if self.retry_policy is not None:
+                # get num_retries from retry policy
+                _retry_policy_retries = self.get_num_retries_from_retry_policy(
+                    exception=original_exception
+                )
+                if _retry_policy_retries is not None:
+                    num_retries = _retry_policy_retries
+
             if num_retries > 0:
                 kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
 
@@ -3254,6 +3267,51 @@ class Router:
         except Exception as e:
             verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
 
+    def get_num_retries_from_retry_policy(self, exception: Exception) -> Optional[int]:
+        """
+        Return the number of retries `self.retry_policy` sets for `exception`.
+
+        RetryPolicy fields consulted:
+            BadRequestErrorRetries: Optional[int] = None
+            AuthenticationErrorRetries: Optional[int] = None
+            TimeoutErrorRetries: Optional[int] = None
+            RateLimitErrorRetries: Optional[int] = None
+            ContentPolicyViolationErrorRetries: Optional[int] = None
+
+        Returns None when no retry policy is set, or when the policy has no value
+        for this exception type, so the caller keeps its existing `num_retries`.
+        """
+        if self.retry_policy is None:
+            return None
+        # Check the most specific type first: litellm.ContentPolicyViolationError
+        # derives from litellm.BadRequestError, so it must not be shadowed by it.
+        if (
+            isinstance(exception, litellm.ContentPolicyViolationError)
+            and self.retry_policy.ContentPolicyViolationErrorRetries is not None
+        ):
+            return self.retry_policy.ContentPolicyViolationErrorRetries
+        if (
+            isinstance(exception, litellm.BadRequestError)
+            and self.retry_policy.BadRequestErrorRetries is not None
+        ):
+            return self.retry_policy.BadRequestErrorRetries
+        if (
+            isinstance(exception, litellm.AuthenticationError)
+            and self.retry_policy.AuthenticationErrorRetries is not None
+        ):
+            return self.retry_policy.AuthenticationErrorRetries
+        if (
+            isinstance(exception, litellm.Timeout)
+            and self.retry_policy.TimeoutErrorRetries is not None
+        ):
+            return self.retry_policy.TimeoutErrorRetries
+        if (
+            isinstance(exception, litellm.RateLimitError)
+            and self.retry_policy.RateLimitErrorRetries is not None
+        ):
+            return self.retry_policy.RateLimitErrorRetries
+        return None
+
     def flush_cache(self):
         litellm.cache = None
         self.cache.flush_cache()