feat - set custom AllowedFailsPolicy

This commit is contained in:
Ishaan Jaff 2024-06-01 17:26:21 -07:00
parent 3b94993ddc
commit eb203c051a
2 changed files with 77 additions and 3 deletions

View file

@ -47,6 +47,7 @@ from litellm.types.router import (
updateDeployment, updateDeployment,
updateLiteLLMParams, updateLiteLLMParams,
RetryPolicy, RetryPolicy,
AllowedFailsPolicy,
AlertingConfig, AlertingConfig,
DeploymentTypedDict, DeploymentTypedDict,
ModelGroupInfo, ModelGroupInfo,
@ -113,6 +114,9 @@ class Router:
allowed_fails: Optional[ allowed_fails: Optional[
int int
] = None, # Number of times a deployment can fail before being added to cooldown ] = None, # Number of times a deployment can fail before being added to cooldown
allowed_fails_policy: Optional[
AllowedFailsPolicy
] = None, # set custom allowed fails policy
cooldown_time: Optional[ cooldown_time: Optional[
float float
] = None, # (seconds) time to cooldown a deployment after failure ] = None, # (seconds) time to cooldown a deployment after failure
@ -355,6 +359,7 @@ class Router:
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy model_group_retry_policy
) )
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
self.alerting_config: Optional[AlertingConfig] = alerting_config self.alerting_config: Optional[AlertingConfig] = alerting_config
if self.alerting_config is not None: if self.alerting_config is not None:
self._initialize_alerting() self._initialize_alerting()
@ -2350,6 +2355,7 @@ class Router:
deployment_id = _model_info.get("id", None) deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments( self._set_cooldown_deployments(
exception_status=exception_status, exception_status=exception_status,
original_exception=exception,
deployment=deployment_id, deployment=deployment_id,
time_to_cooldown=_time_to_cooldown, time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments ) # setting deployment_id in cooldown deployments
@ -2455,6 +2461,7 @@ class Router:
def _set_cooldown_deployments( def _set_cooldown_deployments(
self, self,
original_exception: Any,
exception_status: Union[str, int], exception_status: Union[str, int],
deployment: Optional[str] = None, deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None, time_to_cooldown: Optional[float] = None,
@ -2473,6 +2480,12 @@ class Router:
if self._is_cooldown_required(exception_status=exception_status) == False: if self._is_cooldown_required(exception_status=exception_status) == False:
return return
_allowed_fails = self.get_allowed_fails_from_policy(
exception=original_exception,
)
allowed_fails = _allowed_fails or self.allowed_fails
dt = get_utc_datetime() dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M") current_minute = dt.strftime("%H-%M")
# get current fails for deployment # get current fails for deployment
@ -2482,7 +2495,7 @@ class Router:
current_fails = self.failed_calls.get_cache(key=deployment) or 0 current_fails = self.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1 updated_fails = current_fails + 1
verbose_router_logger.debug( verbose_router_logger.debug(
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}" f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
) )
cooldown_time = self.cooldown_time or 1 cooldown_time = self.cooldown_time or 1
if time_to_cooldown is not None: if time_to_cooldown is not None:
@ -2499,7 +2512,8 @@ class Router:
) )
exception_status = 500 exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status) _should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > self.allowed_fails or _should_retry == False:
if updated_fails > allowed_fails or _should_retry == False:
# get the current cooldown list for that minute # get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(key=cooldown_key) cached_value = self.cache.get_cache(key=cooldown_key)
@ -2642,6 +2656,7 @@ class Router:
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
self._set_cooldown_deployments( self._set_cooldown_deployments(
exception_status=e.status_code, exception_status=e.status_code,
original_exception=e,
deployment=deployment["model_info"]["id"], deployment=deployment["model_info"]["id"],
time_to_cooldown=self.cooldown_time, time_to_cooldown=self.cooldown_time,
) )
@ -4334,6 +4349,46 @@ class Router:
): ):
return retry_policy.ContentPolicyViolationErrorRetries return retry_policy.ContentPolicyViolationErrorRetries
def get_allowed_fails_from_policy(self, exception: Exception):
    """
    Look up a custom allowed_fails count for *exception* in self.allowed_fails_policy.

    Matches the exception's type against the per-exception-type fields of
    AllowedFailsPolicy:
        BadRequestErrorAllowedFails
        AuthenticationErrorAllowedFails
        TimeoutErrorAllowedFails
        RateLimitErrorAllowedFails
        ContentPolicyViolationErrorAllowedFails
        InternalServerErrorAllowedFails

    Returns:
        int: the configured allowed_fails for the first matching exception type.
        None: when no policy is set, or the policy has no value configured for
            this exception type (callers fall back to self.allowed_fails).
    """
    allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
    if allowed_fails_policy is None:
        return None

    # Order mirrors the original chain of isinstance checks; the first
    # matching exception type with a configured (non-None) value wins.
    # NOTE(review): InternalServerErrorAllowedFails was defined on
    # AllowedFailsPolicy but never consulted here — added as the final check.
    _policy_by_exception_type = [
        (litellm.BadRequestError, allowed_fails_policy.BadRequestErrorAllowedFails),
        (litellm.AuthenticationError, allowed_fails_policy.AuthenticationErrorAllowedFails),
        (litellm.Timeout, allowed_fails_policy.TimeoutErrorAllowedFails),
        (litellm.RateLimitError, allowed_fails_policy.RateLimitErrorAllowedFails),
        (
            litellm.ContentPolicyViolationError,
            allowed_fails_policy.ContentPolicyViolationErrorAllowedFails,
        ),
        (
            litellm.InternalServerError,
            allowed_fails_policy.InternalServerErrorAllowedFails,
        ),
    ]
    for _exception_type, _allowed_fails in _policy_by_exception_type:
        if isinstance(exception, _exception_type) and _allowed_fails is not None:
            return _allowed_fails
    return None
def _initialize_alerting(self): def _initialize_alerting(self):
from litellm.integrations.slack_alerting import SlackAlerting from litellm.integrations.slack_alerting import SlackAlerting

View file

@ -76,7 +76,9 @@ class ModelInfo(BaseModel):
id: Optional[ id: Optional[
str str
] # Allow id to be optional on input, but it will always be present as a str in the model instance ] # Allow id to be optional on input, but it will always be present as a str in the model instance
db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config. db_model: bool = (
False # used for proxy - to separate models which are stored in the db vs. config.
)
updated_at: Optional[datetime.datetime] = None updated_at: Optional[datetime.datetime] = None
updated_by: Optional[str] = None updated_by: Optional[str] = None
@ -381,6 +383,23 @@ class RouterErrors(enum.Enum):
no_deployments_available = "No deployments available for selected model" no_deployments_available = "No deployments available for selected model"
class AllowedFailsPolicy(BaseModel):
    """
    Set a custom number of allowed_fails per exception type, before cooling
    down a deployment.

    Example: if RateLimitErrorAllowedFails = 3, a deployment may raise 3
    RateLimitErrors before it is placed in cooldown.

    A value of None for a field means no custom limit for that exception type
    (the router's default allowed_fails applies).

    Mapping of exception type -> allowed_fails:
    https://docs.litellm.ai/docs/exception_mapping
    """

    BadRequestErrorAllowedFails: Optional[int] = None
    AuthenticationErrorAllowedFails: Optional[int] = None
    TimeoutErrorAllowedFails: Optional[int] = None
    RateLimitErrorAllowedFails: Optional[int] = None
    ContentPolicyViolationErrorAllowedFails: Optional[int] = None
    InternalServerErrorAllowedFails: Optional[int] = None
class RetryPolicy(BaseModel): class RetryPolicy(BaseModel):
""" """
Use this to set a custom number of retries per exception type Use this to set a custom number of retries per exception type