feat - set custom AllowedFailsPolicy

This commit is contained in:
Ishaan Jaff 2024-06-01 17:26:21 -07:00
parent 3b94993ddc
commit eb203c051a
2 changed files with 77 additions and 3 deletions

View file

@ -47,6 +47,7 @@ from litellm.types.router import (
updateDeployment,
updateLiteLLMParams,
RetryPolicy,
AllowedFailsPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
@ -113,6 +114,9 @@ class Router:
allowed_fails: Optional[
int
] = None, # Number of times a deployment can failbefore being added to cooldown
allowed_fails_policy: Optional[
AllowedFailsPolicy
] = None, # set custom allowed fails policy
cooldown_time: Optional[
float
] = None, # (seconds) time to cooldown a deployment after failure
@ -355,6 +359,7 @@ class Router:
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
self.alerting_config: Optional[AlertingConfig] = alerting_config
if self.alerting_config is not None:
self._initialize_alerting()
@ -2350,6 +2355,7 @@ class Router:
deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments(
exception_status=exception_status,
original_exception=exception,
deployment=deployment_id,
time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments
@ -2455,6 +2461,7 @@ class Router:
def _set_cooldown_deployments(
self,
original_exception: Any,
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
@ -2473,6 +2480,12 @@ class Router:
if self._is_cooldown_required(exception_status=exception_status) == False:
return
_allowed_fails = self.get_allowed_fails_from_policy(
exception=original_exception,
)
allowed_fails = _allowed_fails or self.allowed_fails
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get current fails for deployment
@ -2482,7 +2495,7 @@ class Router:
current_fails = self.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1
verbose_router_logger.debug(
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
)
cooldown_time = self.cooldown_time or 1
if time_to_cooldown is not None:
@ -2499,7 +2512,8 @@ class Router:
)
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > self.allowed_fails or _should_retry == False:
if updated_fails > allowed_fails or _should_retry == False:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(key=cooldown_key)
@ -2642,6 +2656,7 @@ class Router:
except litellm.RateLimitError as e:
self._set_cooldown_deployments(
exception_status=e.status_code,
original_exception=e,
deployment=deployment["model_info"]["id"],
time_to_cooldown=self.cooldown_time,
)
@ -4334,6 +4349,46 @@ class Router:
):
return retry_policy.ContentPolicyViolationErrorRetries
def get_allowed_fails_from_policy(self, exception: Exception):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception then in the retry policy -> return the number of retries
allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
if allowed_fails_policy is None:
return None
if (
isinstance(exception, litellm.BadRequestError)
and allowed_fails_policy.BadRequestErrorAllowedFails is not None
):
return allowed_fails_policy.BadRequestErrorAllowedFails
if (
isinstance(exception, litellm.AuthenticationError)
and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
):
return allowed_fails_policy.AuthenticationErrorAllowedFails
if (
isinstance(exception, litellm.Timeout)
and allowed_fails_policy.TimeoutErrorAllowedFails is not None
):
return allowed_fails_policy.TimeoutErrorAllowedFails
if (
isinstance(exception, litellm.RateLimitError)
and allowed_fails_policy.RateLimitErrorAllowedFails is not None
):
return allowed_fails_policy.RateLimitErrorAllowedFails
if (
isinstance(exception, litellm.ContentPolicyViolationError)
and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
):
return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
def _initialize_alerting(self):
from litellm.integrations.slack_alerting import SlackAlerting