Merge pull request #3963 from BerriAI/litellm_set_allowed_fail_policy

[FEAT]- set custom AllowedFailsPolicy on litellm.Router
This commit is contained in:
Ishaan Jaff 2024-06-01 17:57:11 -07:00 committed by GitHub
commit e149ca73f6
4 changed files with 109 additions and 11 deletions

View file

@ -713,18 +713,30 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
print(f"response: {response}")
```
#### Retries based on Error Type
### [Advanced]: Custom Retries, Cooldowns based on Error Type
Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment
Example:
- 4 retries for `ContentPolicyViolationError`
- 0 retries for `RateLimitErrors`
```python
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
)
allowed_fails_policy = AllowedFailsPolicy(
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
)
```
Example Usage
```python
from litellm.router import RetryPolicy
from litellm.router import RetryPolicy, AllowedFailsPolicy
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
@ -733,6 +745,11 @@ retry_policy = RetryPolicy(
RateLimitErrorRetries=3,
)
allowed_fails_policy = AllowedFailsPolicy(
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
)
router = litellm.Router(
model_list=[
{
@ -755,6 +772,7 @@ router = litellm.Router(
},
],
retry_policy=retry_policy,
allowed_fails_policy=allowed_fails_policy,
)
response = await router.acompletion(

View file

@ -47,6 +47,7 @@ from litellm.types.router import (
updateDeployment,
updateLiteLLMParams,
RetryPolicy,
AllowedFailsPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
@ -116,6 +117,9 @@ class Router:
allowed_fails: Optional[
int
] = None, # Number of times a deployment can failbefore being added to cooldown
allowed_fails_policy: Optional[
AllowedFailsPolicy
] = None, # set custom allowed fails policy
cooldown_time: Optional[
float
] = None, # (seconds) time to cooldown a deployment after failure
@ -361,6 +365,7 @@ class Router:
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
self.alerting_config: Optional[AlertingConfig] = alerting_config
if self.alerting_config is not None:
self._initialize_alerting()
@ -2445,6 +2450,7 @@ class Router:
deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments(
exception_status=exception_status,
original_exception=exception,
deployment=deployment_id,
time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments
@ -2550,6 +2556,7 @@ class Router:
def _set_cooldown_deployments(
self,
original_exception: Any,
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
@ -2568,6 +2575,12 @@ class Router:
if self._is_cooldown_required(exception_status=exception_status) == False:
return
_allowed_fails = self.get_allowed_fails_from_policy(
exception=original_exception,
)
allowed_fails = _allowed_fails or self.allowed_fails
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get current fails for deployment
@ -2577,7 +2590,7 @@ class Router:
current_fails = self.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1
verbose_router_logger.debug(
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
)
cooldown_time = self.cooldown_time or 1
if time_to_cooldown is not None:
@ -2594,7 +2607,8 @@ class Router:
)
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > self.allowed_fails or _should_retry == False:
if updated_fails > allowed_fails or _should_retry == False:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(key=cooldown_key)
@ -2737,6 +2751,7 @@ class Router:
except litellm.RateLimitError as e:
self._set_cooldown_deployments(
exception_status=e.status_code,
original_exception=e,
deployment=deployment["model_info"]["id"],
time_to_cooldown=self.cooldown_time,
)
@ -4429,6 +4444,46 @@ class Router:
):
return retry_policy.ContentPolicyViolationErrorRetries
def get_allowed_fails_from_policy(self, exception: Exception):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception then in the retry policy -> return the number of retries
allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
if allowed_fails_policy is None:
return None
if (
isinstance(exception, litellm.BadRequestError)
and allowed_fails_policy.BadRequestErrorAllowedFails is not None
):
return allowed_fails_policy.BadRequestErrorAllowedFails
if (
isinstance(exception, litellm.AuthenticationError)
and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
):
return allowed_fails_policy.AuthenticationErrorAllowedFails
if (
isinstance(exception, litellm.Timeout)
and allowed_fails_policy.TimeoutErrorAllowedFails is not None
):
return allowed_fails_policy.TimeoutErrorAllowedFails
if (
isinstance(exception, litellm.RateLimitError)
and allowed_fails_policy.RateLimitErrorAllowedFails is not None
):
return allowed_fails_policy.RateLimitErrorAllowedFails
if (
isinstance(exception, litellm.ContentPolicyViolationError)
and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
):
return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
def _initialize_alerting(self):
from litellm.integrations.slack_alerting import SlackAlerting

View file

@ -128,12 +128,17 @@ async def test_router_retries_errors(sync_mode, error_type):
["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], #
)
async def test_router_retry_policy(error_type):
from litellm.router import RetryPolicy
from litellm.router import RetryPolicy, AllowedFailsPolicy
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0
)
allowed_fails_policy = AllowedFailsPolicy(
ContentPolicyViolationErrorAllowedFails=1000,
RateLimitErrorAllowedFails=100,
)
router = Router(
model_list=[
{
@ -156,6 +161,7 @@ async def test_router_retry_policy(error_type):
},
],
retry_policy=retry_policy,
allowed_fails_policy=allowed_fails_policy,
)
customHandler = MyCustomHandler()

View file

@ -76,7 +76,9 @@ class ModelInfo(BaseModel):
id: Optional[
str
] # Allow id to be optional on input, but it will always be present as a str in the model instance
db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config.
db_model: bool = (
False # used for proxy - to separate models which are stored in the db vs. config.
)
updated_at: Optional[datetime.datetime] = None
updated_by: Optional[str] = None
@ -381,6 +383,23 @@ class RouterErrors(enum.Enum):
no_deployments_available = "No deployments available for selected model"
class AllowedFailsPolicy(BaseModel):
"""
Use this to set a custom number of allowed fails/minute before cooling down a deployment
If `AuthenticationErrorAllowedFails = 1000`, then 1000 AuthenticationError will be allowed before cooling down a deployment
Mapping of Exception type to allowed_fails for each exception
https://docs.litellm.ai/docs/exception_mapping
"""
BadRequestErrorAllowedFails: Optional[int] = None
AuthenticationErrorAllowedFails: Optional[int] = None
TimeoutErrorAllowedFails: Optional[int] = None
RateLimitErrorAllowedFails: Optional[int] = None
ContentPolicyViolationErrorAllowedFails: Optional[int] = None
InternalServerErrorAllowedFails: Optional[int] = None
class RetryPolicy(BaseModel):
"""
Use this to set a custom number of retries per exception type