mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Merge pull request #3963 from BerriAI/litellm_set_allowed_fail_policy
[FEAT]- set custom AllowedFailsPolicy on litellm.Router
This commit is contained in:
commit
e149ca73f6
4 changed files with 109 additions and 11 deletions
|
@ -713,26 +713,43 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
|
|||
print(f"response: {response}")
|
||||
```
|
||||
|
||||
#### Retries based on Error Type
|
||||
### [Advanced]: Custom Retries, Cooldowns based on Error Type
|
||||
|
||||
Use `RetryPolicy` if you want to set a `num_retries` based on the Exception received
|
||||
- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception received
|
||||
- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment
|
||||
|
||||
Example:
|
||||
- 4 retries for `ContentPolicyViolationError`
|
||||
- 0 retries for `RateLimitErrors`
|
||||
|
||||
```python
|
||||
retry_policy = RetryPolicy(
|
||||
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
||||
)
|
||||
|
||||
allowed_fails_policy = AllowedFailsPolicy(
|
||||
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
|
||||
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
|
||||
)
|
||||
```
|
||||
|
||||
Example Usage
|
||||
|
||||
```python
|
||||
from litellm.router import RetryPolicy
|
||||
from litellm.router import RetryPolicy, AllowedFailsPolicy
|
||||
|
||||
retry_policy = RetryPolicy(
|
||||
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
||||
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
||||
BadRequestErrorRetries=1,
|
||||
TimeoutErrorRetries=2,
|
||||
RateLimitErrorRetries=3,
|
||||
)
|
||||
|
||||
allowed_fails_policy = AllowedFailsPolicy(
|
||||
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
|
||||
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
|
||||
)
|
||||
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
|
@ -755,6 +772,7 @@ router = litellm.Router(
|
|||
},
|
||||
],
|
||||
retry_policy=retry_policy,
|
||||
allowed_fails_policy=allowed_fails_policy,
|
||||
)
|
||||
|
||||
response = await router.acompletion(
|
||||
|
|
|
@ -47,6 +47,7 @@ from litellm.types.router import (
|
|||
updateDeployment,
|
||||
updateLiteLLMParams,
|
||||
RetryPolicy,
|
||||
AllowedFailsPolicy,
|
||||
AlertingConfig,
|
||||
DeploymentTypedDict,
|
||||
ModelGroupInfo,
|
||||
|
@ -116,6 +117,9 @@ class Router:
|
|||
allowed_fails: Optional[
|
||||
int
|
||||
] = None, # Number of times a deployment can fail before being added to cooldown
|
||||
allowed_fails_policy: Optional[
|
||||
AllowedFailsPolicy
|
||||
] = None, # set custom allowed fails policy
|
||||
cooldown_time: Optional[
|
||||
float
|
||||
] = None, # (seconds) time to cooldown a deployment after failure
|
||||
|
@ -361,6 +365,7 @@ class Router:
|
|||
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
|
||||
model_group_retry_policy
|
||||
)
|
||||
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
|
||||
self.alerting_config: Optional[AlertingConfig] = alerting_config
|
||||
if self.alerting_config is not None:
|
||||
self._initialize_alerting()
|
||||
|
@ -2445,6 +2450,7 @@ class Router:
|
|||
deployment_id = _model_info.get("id", None)
|
||||
self._set_cooldown_deployments(
|
||||
exception_status=exception_status,
|
||||
original_exception=exception,
|
||||
deployment=deployment_id,
|
||||
time_to_cooldown=_time_to_cooldown,
|
||||
) # setting deployment_id in cooldown deployments
|
||||
|
@ -2550,6 +2556,7 @@ class Router:
|
|||
|
||||
def _set_cooldown_deployments(
|
||||
self,
|
||||
original_exception: Any,
|
||||
exception_status: Union[str, int],
|
||||
deployment: Optional[str] = None,
|
||||
time_to_cooldown: Optional[float] = None,
|
||||
|
@ -2568,6 +2575,12 @@ class Router:
|
|||
if self._is_cooldown_required(exception_status=exception_status) == False:
|
||||
return
|
||||
|
||||
_allowed_fails = self.get_allowed_fails_from_policy(
|
||||
exception=original_exception,
|
||||
)
|
||||
|
||||
allowed_fails = _allowed_fails or self.allowed_fails
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
# get current fails for deployment
|
||||
|
@ -2577,7 +2590,7 @@ class Router:
|
|||
current_fails = self.failed_calls.get_cache(key=deployment) or 0
|
||||
updated_fails = current_fails + 1
|
||||
verbose_router_logger.debug(
|
||||
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
|
||||
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
|
||||
)
|
||||
cooldown_time = self.cooldown_time or 1
|
||||
if time_to_cooldown is not None:
|
||||
|
@ -2594,7 +2607,8 @@ class Router:
|
|||
)
|
||||
exception_status = 500
|
||||
_should_retry = litellm._should_retry(status_code=exception_status)
|
||||
if updated_fails > self.allowed_fails or _should_retry == False:
|
||||
|
||||
if updated_fails > allowed_fails or _should_retry == False:
|
||||
# get the current cooldown list for that minute
|
||||
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
|
||||
cached_value = self.cache.get_cache(key=cooldown_key)
|
||||
|
@ -2737,6 +2751,7 @@ class Router:
|
|||
except litellm.RateLimitError as e:
|
||||
self._set_cooldown_deployments(
|
||||
exception_status=e.status_code,
|
||||
original_exception=e,
|
||||
deployment=deployment["model_info"]["id"],
|
||||
time_to_cooldown=self.cooldown_time,
|
||||
)
|
||||
|
@ -4429,6 +4444,46 @@ class Router:
|
|||
):
|
||||
return retry_policy.ContentPolicyViolationErrorRetries
|
||||
|
||||
def get_allowed_fails_from_policy(self, exception: Exception):
    """
    Look up a custom allowed_fails value for this exception from self.allowed_fails_policy.

    AllowedFailsPolicy fields consulted:
        BadRequestErrorAllowedFails: Optional[int] = None
        AuthenticationErrorAllowedFails: Optional[int] = None
        TimeoutErrorAllowedFails: Optional[int] = None
        RateLimitErrorAllowedFails: Optional[int] = None
        ContentPolicyViolationErrorAllowedFails: Optional[int] = None
        InternalServerErrorAllowedFails: Optional[int] = None

    Returns:
        int: the configured allowed_fails for the exception's type, or
        None when no policy is set / the policy has no value for this
        exception type (caller falls back to the router-level allowed_fails).
    """
    # if we can find the exception in the allowed fails policy -> return its allowed_fails
    allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy

    if allowed_fails_policy is None:
        return None

    # NOTE(review): ContentPolicyViolationError subclasses BadRequestError in
    # litellm's exception hierarchy — TODO confirm — so it must be checked
    # BEFORE BadRequestError, otherwise a configured
    # ContentPolicyViolationErrorAllowedFails would be shadowed by
    # BadRequestErrorAllowedFails.
    if (
        isinstance(exception, litellm.ContentPolicyViolationError)
        and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
    ):
        return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
    if (
        isinstance(exception, litellm.BadRequestError)
        and allowed_fails_policy.BadRequestErrorAllowedFails is not None
    ):
        return allowed_fails_policy.BadRequestErrorAllowedFails
    if (
        isinstance(exception, litellm.AuthenticationError)
        and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
    ):
        return allowed_fails_policy.AuthenticationErrorAllowedFails
    if (
        isinstance(exception, litellm.Timeout)
        and allowed_fails_policy.TimeoutErrorAllowedFails is not None
    ):
        return allowed_fails_policy.TimeoutErrorAllowedFails
    if (
        isinstance(exception, litellm.RateLimitError)
        and allowed_fails_policy.RateLimitErrorAllowedFails is not None
    ):
        return allowed_fails_policy.RateLimitErrorAllowedFails
    # InternalServerErrorAllowedFails is declared on AllowedFailsPolicy but was
    # previously never consulted — honor it here.
    if (
        isinstance(exception, litellm.InternalServerError)
        and allowed_fails_policy.InternalServerErrorAllowedFails is not None
    ):
        return allowed_fails_policy.InternalServerErrorAllowedFails
    # No override configured for this exception type.
    return None
|
||||
|
||||
def _initialize_alerting(self):
|
||||
from litellm.integrations.slack_alerting import SlackAlerting
|
||||
|
||||
|
|
|
@ -128,12 +128,17 @@ async def test_router_retries_errors(sync_mode, error_type):
|
|||
["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], #
|
||||
)
|
||||
async def test_router_retry_policy(error_type):
|
||||
from litellm.router import RetryPolicy
|
||||
from litellm.router import RetryPolicy, AllowedFailsPolicy
|
||||
|
||||
retry_policy = RetryPolicy(
|
||||
ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0
|
||||
)
|
||||
|
||||
allowed_fails_policy = AllowedFailsPolicy(
|
||||
ContentPolicyViolationErrorAllowedFails=1000,
|
||||
RateLimitErrorAllowedFails=100,
|
||||
)
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
|
@ -156,6 +161,7 @@ async def test_router_retry_policy(error_type):
|
|||
},
|
||||
],
|
||||
retry_policy=retry_policy,
|
||||
allowed_fails_policy=allowed_fails_policy,
|
||||
)
|
||||
|
||||
customHandler = MyCustomHandler()
|
||||
|
|
|
@ -76,7 +76,9 @@ class ModelInfo(BaseModel):
|
|||
id: Optional[
|
||||
str
|
||||
] # Allow id to be optional on input, but it will always be present as a str in the model instance
|
||||
db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config.
|
||||
db_model: bool = (
|
||||
False # used for proxy - to separate models which are stored in the db vs. config.
|
||||
)
|
||||
updated_at: Optional[datetime.datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
|
||||
|
@ -381,6 +383,23 @@ class RouterErrors(enum.Enum):
|
|||
no_deployments_available = "No deployments available for selected model"
|
||||
|
||||
|
||||
class AllowedFailsPolicy(BaseModel):
    """
    Use this to set a custom number of allowed fails/minute before cooling down a deployment.

    If `AuthenticationErrorAllowedFails = 1000`, then 1000 AuthenticationErrors will be
    allowed before cooling down a deployment.

    Mapping of exception type to allowed_fails for each exception:
    https://docs.litellm.ai/docs/exception_mapping
    """

    # None means "no per-exception override" — the router falls back to its
    # global allowed_fails setting for that exception type.
    BadRequestErrorAllowedFails: Optional[int] = None
    AuthenticationErrorAllowedFails: Optional[int] = None
    TimeoutErrorAllowedFails: Optional[int] = None
    RateLimitErrorAllowedFails: Optional[int] = None
    ContentPolicyViolationErrorAllowedFails: Optional[int] = None
    InternalServerErrorAllowedFails: Optional[int] = None
|
||||
|
||||
|
||||
class RetryPolicy(BaseModel):
|
||||
"""
|
||||
Use this to set a custom number of retries per exception type
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue