From eb203c051ad23ef34522fc9d0c17019e67ebfa88 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 1 Jun 2024 17:26:21 -0700
Subject: [PATCH] feat - set custom AllowedFailsPolicy

---
 litellm/router.py       | 59 +++++++++++++++++++++++++++++++++++++++--
 litellm/types/router.py | 21 ++++++++++++++-
 2 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index 88eb54a04..2a1e8a122 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -47,6 +47,7 @@ from litellm.types.router import (
     updateDeployment,
     updateLiteLLMParams,
     RetryPolicy,
+    AllowedFailsPolicy,
     AlertingConfig,
     DeploymentTypedDict,
     ModelGroupInfo,
@@ -113,6 +114,9 @@ class Router:
         allowed_fails: Optional[
             int
         ] = None,  # Number of times a deployment can fail before being added to cooldown
+        allowed_fails_policy: Optional[
+            AllowedFailsPolicy
+        ] = None,  # set a custom allowed fails policy per exception type
         cooldown_time: Optional[
             float
         ] = None,  # (seconds) time to cooldown a deployment after failure
@@ -355,6 +359,7 @@ class Router:
         self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
             model_group_retry_policy
         )
+        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
         self.alerting_config: Optional[AlertingConfig] = alerting_config
         if self.alerting_config is not None:
             self._initialize_alerting()
@@ -2350,6 +2355,7 @@ class Router:
                     deployment_id = _model_info.get("id", None)
                     self._set_cooldown_deployments(
                         exception_status=exception_status,
+                        original_exception=exception,
                         deployment=deployment_id,
                         time_to_cooldown=_time_to_cooldown,
                     )  # setting deployment_id in cooldown deployments
@@ -2455,6 +2461,7 @@
 
     def _set_cooldown_deployments(
         self,
+        original_exception: Any,
         exception_status: Union[str, int],
         deployment: Optional[str] = None,
         time_to_cooldown: Optional[float] = None,
@@ -2473,6 +2480,12 @@
         if self._is_cooldown_required(exception_status=exception_status) == False:
             return
 
+        _allowed_fails = self.get_allowed_fails_from_policy(
+            exception=original_exception,
+        )
+
+        allowed_fails = _allowed_fails or self.allowed_fails
+
         dt = get_utc_datetime()
         current_minute = dt.strftime("%H-%M")
         # get current fails for deployment
@@ -2482,7 +2495,7 @@
         current_fails = self.failed_calls.get_cache(key=deployment) or 0
         updated_fails = current_fails + 1
         verbose_router_logger.debug(
-            f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
+            f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; allowed_fails: {allowed_fails}"
         )
         cooldown_time = self.cooldown_time or 1
         if time_to_cooldown is not None:
@@ -2499,7 +2512,8 @@
             )
             exception_status = 500
         _should_retry = litellm._should_retry(status_code=exception_status)
-        if updated_fails > self.allowed_fails or _should_retry == False:
+
+        if updated_fails > allowed_fails or _should_retry == False:
             # get the current cooldown list for that minute
             cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
             cached_value = self.cache.get_cache(key=cooldown_key)
@@ -2642,6 +2656,7 @@ class Router:
         except litellm.RateLimitError as e:
             self._set_cooldown_deployments(
                 exception_status=e.status_code,
+                original_exception=e,
                 deployment=deployment["model_info"]["id"],
                 time_to_cooldown=self.cooldown_time,
             )
@@ -4334,6 +4349,46 @@ class Router:
         ):
             return retry_policy.ContentPolicyViolationErrorRetries
 
+    def get_allowed_fails_from_policy(self, exception: Exception):
+        """
+        BadRequestErrorAllowedFails: Optional[int] = None
+        AuthenticationErrorAllowedFails: Optional[int] = None
+        TimeoutErrorAllowedFails: Optional[int] = None
+        RateLimitErrorAllowedFails: Optional[int] = None
+        ContentPolicyViolationErrorAllowedFails: Optional[int] = None
+        """
+        # if we can find the exception in the allowed fails policy -> return the allowed fails for that exception type
+        allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
+
+        if allowed_fails_policy is None:
+            return None
+
+        if (
+            isinstance(exception, litellm.BadRequestError)
+            and allowed_fails_policy.BadRequestErrorAllowedFails is not None
+        ):
+            return allowed_fails_policy.BadRequestErrorAllowedFails
+        if (
+            isinstance(exception, litellm.AuthenticationError)
+            and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
+        ):
+            return allowed_fails_policy.AuthenticationErrorAllowedFails
+        if (
+            isinstance(exception, litellm.Timeout)
+            and allowed_fails_policy.TimeoutErrorAllowedFails is not None
+        ):
+            return allowed_fails_policy.TimeoutErrorAllowedFails
+        if (
+            isinstance(exception, litellm.RateLimitError)
+            and allowed_fails_policy.RateLimitErrorAllowedFails is not None
+        ):
+            return allowed_fails_policy.RateLimitErrorAllowedFails
+        if (
+            isinstance(exception, litellm.ContentPolicyViolationError)
+            and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
+        ):
+            return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
+
     def _initialize_alerting(self):
         from litellm.integrations.slack_alerting import SlackAlerting
 
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 8fed461cb..4a1f4498c 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -76,7 +76,9 @@ class ModelInfo(BaseModel):
     id: Optional[
         str
     ]  # Allow id to be optional on input, but it will always be present as a str in the model instance
-    db_model: bool = False  # used for proxy - to separate models which are stored in the db vs. config.
+    db_model: bool = (
+        False  # used for proxy - to separate models which are stored in the db vs. config.
+    )
     updated_at: Optional[datetime.datetime] = None
     updated_by: Optional[str] = None
 
@@ -381,6 +383,23 @@ class RouterErrors(enum.Enum):
     no_deployments_available = "No deployments available for selected model"
 
 
+class AllowedFailsPolicy(BaseModel):
+    """
+    Use this to set a custom number of allowed_fails per exception type before cooling down a deployment.
+    If RateLimitErrorAllowedFails = 3, then 3 RateLimitErrors are allowed on a deployment before it is cooled down.
+
+    Mapping of Exception type to allowed_fails for each exception
+    https://docs.litellm.ai/docs/exception_mapping
+    """
+
+    BadRequestErrorAllowedFails: Optional[int] = None
+    AuthenticationErrorAllowedFails: Optional[int] = None
+    TimeoutErrorAllowedFails: Optional[int] = None
+    RateLimitErrorAllowedFails: Optional[int] = None
+    ContentPolicyViolationErrorAllowedFails: Optional[int] = None
+    InternalServerErrorAllowedFails: Optional[int] = None
+
+
 class RetryPolicy(BaseModel):
     """
     Use this to set a custom number of retries per exception type
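
Example usage of the new allowed_fails_policy parameter. This is a minimal sketch, not part of the patch: the model list, API key handling, and thresholds are illustrative placeholders.

    import os

    from litellm import Router
    from litellm.types.router import AllowedFailsPolicy

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",  # placeholder deployment
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        ],
        allowed_fails=3,  # default threshold for exception types not listed below
        allowed_fails_policy=AllowedFailsPolicy(
            RateLimitErrorAllowedFails=100,  # tolerate transient 429s before cooling down
            AuthenticationErrorAllowedFails=1,  # cool down fast; retries won't fix a bad key
        ),
    )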
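
How the policy is consulted, per _set_cooldown_deployments above: the per-exception value wins when set, and unset exception types fall back to the router-level allowed_fails (allowed_fails = _allowed_fails or self.allowed_fails). A sketch continuing from the router above; the exception constructor arguments shown are assumptions for illustration:

    import litellm

    # RateLimitError is set in the policy above -> returns 100
    rate_limit_fails = router.get_allowed_fails_from_policy(
        exception=litellm.RateLimitError(
            message="429: rate limited",
            llm_provider="openai",
            model="gpt-3.5-turbo",
        )
    )

    # Timeout is unset in the policy -> returns None, so
    # _set_cooldown_deployments falls back to router.allowed_fails (3 here)
    timeout_fails = router.get_allowed_fails_from_policy(
        exception=litellm.Timeout(
            message="request timed out",
            model="gpt-3.5-turbo",
            llm_provider="openai",
        )
    )

Note that get_allowed_fails_from_policy checks five exception types; InternalServerErrorAllowedFails is defined on AllowedFailsPolicy but is not consulted in this patch.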