diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index 5ba3221c97..d91912644f 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -713,26 +713,43 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages) print(f"response: {response}") ``` -#### Retries based on Error Type +### [Advanced]: Custom Retries, Cooldowns based on Error Type -Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved +- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved +- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment Example: -- 4 retries for `ContentPolicyViolationError` -- 0 retries for `RateLimitErrors` + +```python +retry_policy = RetryPolicy( + ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors + AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries +) + +allowed_fails_policy = AllowedFailsPolicy( + ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment + RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment +) +``` Example Usage ```python -from litellm.router import RetryPolicy +from litellm.router import RetryPolicy, AllowedFailsPolicy + retry_policy = RetryPolicy( - ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors - AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries + ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors + AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries BadRequestErrorRetries=1, TimeoutErrorRetries=2, RateLimitErrorRetries=3, ) +allowed_fails_policy = AllowedFailsPolicy( + ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment + RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment +) + router = litellm.Router( model_list=[ { @@ -755,6 +772,7 @@ router = litellm.Router( }, ], retry_policy=retry_policy, + allowed_fails_policy=allowed_fails_policy, ) response = await router.acompletion( diff --git a/litellm/router.py b/litellm/router.py index 0a14ee47e3..89796428e8 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -47,6 +47,7 @@ from litellm.types.router import ( updateDeployment, updateLiteLLMParams, RetryPolicy, + AllowedFailsPolicy, AlertingConfig, DeploymentTypedDict, ModelGroupInfo, @@ -116,6 +117,9 @@ class Router: allowed_fails: Optional[ int ] = None, # Number of times a deployment can failbefore being added to cooldown + allowed_fails_policy: Optional[ + AllowedFailsPolicy + ] = None, # set custom allowed fails policy cooldown_time: Optional[ float ] = None, # (seconds) time to cooldown a deployment after failure @@ -361,6 +365,7 @@ class Router: self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( model_group_retry_policy ) + self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy self.alerting_config: Optional[AlertingConfig] = alerting_config if self.alerting_config is not None: self._initialize_alerting() @@ -2445,6 +2450,7 @@ class Router: deployment_id = _model_info.get("id", None) self._set_cooldown_deployments( exception_status=exception_status, + original_exception=exception, deployment=deployment_id, time_to_cooldown=_time_to_cooldown, ) # setting deployment_id in cooldown deployments @@ -2550,6 +2556,7 @@ class Router: def _set_cooldown_deployments( self, + original_exception: Any, exception_status: Union[str, int], deployment: Optional[str] = None, time_to_cooldown: Optional[float] = None, @@ -2568,6 +2575,12 @@ class Router: if self._is_cooldown_required(exception_status=exception_status) == False: return + _allowed_fails = self.get_allowed_fails_from_policy( + exception=original_exception, + ) + + allowed_fails = _allowed_fails or self.allowed_fails + dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") # get current fails for deployment @@ -2577,7 +2590,7 @@ class Router: current_fails = self.failed_calls.get_cache(key=deployment) or 0 updated_fails = current_fails + 1 verbose_router_logger.debug( - f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}" + f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}" ) cooldown_time = self.cooldown_time or 1 if time_to_cooldown is not None: @@ -2594,7 +2607,8 @@ class Router: ) exception_status = 500 _should_retry = litellm._should_retry(status_code=exception_status) - if updated_fails > self.allowed_fails or _should_retry == False: + + if updated_fails > allowed_fails or _should_retry == False: # get the current cooldown list for that minute cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls cached_value = self.cache.get_cache(key=cooldown_key) @@ -2737,6 +2751,7 @@ class Router: except litellm.RateLimitError as e: self._set_cooldown_deployments( exception_status=e.status_code, + original_exception=e, deployment=deployment["model_info"]["id"], time_to_cooldown=self.cooldown_time, ) @@ -4429,6 +4444,46 @@ class Router: ): return retry_policy.ContentPolicyViolationErrorRetries + def get_allowed_fails_from_policy(self, exception: Exception): + """ + BadRequestErrorRetries: Optional[int] = None + AuthenticationErrorRetries: Optional[int] = None + TimeoutErrorRetries: Optional[int] = None + RateLimitErrorRetries: Optional[int] = None + ContentPolicyViolationErrorRetries: Optional[int] = None + """ + # if we can find the exception then in the retry policy -> return the number of retries + allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy + + if allowed_fails_policy is None: + return None + + if ( + isinstance(exception, litellm.BadRequestError) + and allowed_fails_policy.BadRequestErrorAllowedFails is not None + ): + return allowed_fails_policy.BadRequestErrorAllowedFails + if ( + isinstance(exception, litellm.AuthenticationError) + and allowed_fails_policy.AuthenticationErrorAllowedFails is not None + ): + return allowed_fails_policy.AuthenticationErrorAllowedFails + if ( + isinstance(exception, litellm.Timeout) + and allowed_fails_policy.TimeoutErrorAllowedFails is not None + ): + return allowed_fails_policy.TimeoutErrorAllowedFails + if ( + isinstance(exception, litellm.RateLimitError) + and allowed_fails_policy.RateLimitErrorAllowedFails is not None + ): + return allowed_fails_policy.RateLimitErrorAllowedFails + if ( + isinstance(exception, litellm.ContentPolicyViolationError) + and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None + ): + return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails + def _initialize_alerting(self): from litellm.integrations.slack_alerting import SlackAlerting diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py index 343659b161..58e52fe996 100644 --- a/litellm/tests/test_router_retries.py +++ b/litellm/tests/test_router_retries.py @@ -128,12 +128,17 @@ async def test_router_retries_errors(sync_mode, error_type): ["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], # ) async def test_router_retry_policy(error_type): - from litellm.router import RetryPolicy + from litellm.router import RetryPolicy, AllowedFailsPolicy retry_policy = RetryPolicy( ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0 ) + allowed_fails_policy = AllowedFailsPolicy( + ContentPolicyViolationErrorAllowedFails=1000, + RateLimitErrorAllowedFails=100, + ) + router = Router( model_list=[ { @@ -156,6 +161,7 @@ async def test_router_retry_policy(error_type): }, ], retry_policy=retry_policy, + allowed_fails_policy=allowed_fails_policy, ) customHandler = MyCustomHandler() diff --git a/litellm/types/router.py b/litellm/types/router.py index 8fed461cb6..38ddef361e 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -76,7 +76,9 @@ class ModelInfo(BaseModel): id: Optional[ str ] # Allow id to be optional on input, but it will always be present as a str in the model instance - db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config. + db_model: bool = ( + False # used for proxy - to separate models which are stored in the db vs. config. + ) updated_at: Optional[datetime.datetime] = None updated_by: Optional[str] = None @@ -381,6 +383,23 @@ class RouterErrors(enum.Enum): no_deployments_available = "No deployments available for selected model" +class AllowedFailsPolicy(BaseModel): + """ + Use this to set a custom number of allowed fails/minute before cooling down a deployment + If `AuthenticationErrorAllowedFails = 1000`, then 1000 AuthenticationError will be allowed before cooling down a deployment + + Mapping of Exception type to allowed_fails for each exception + https://docs.litellm.ai/docs/exception_mapping + """ + + BadRequestErrorAllowedFails: Optional[int] = None + AuthenticationErrorAllowedFails: Optional[int] = None + TimeoutErrorAllowedFails: Optional[int] = None + RateLimitErrorAllowedFails: Optional[int] = None + ContentPolicyViolationErrorAllowedFails: Optional[int] = None + InternalServerErrorAllowedFails: Optional[int] = None + + class RetryPolicy(BaseModel): """ Use this to set a custom number of retries per exception type