mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Merge pull request #3963 from BerriAI/litellm_set_allowed_fail_policy
[FEAT]- set custom AllowedFailsPolicy on litellm.Router
This commit is contained in:
commit
e149ca73f6
4 changed files with 109 additions and 11 deletions
|
@ -713,26 +713,43 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Retries based on Error Type
|
### [Advanced]: Custom Retries, Cooldowns based on Error Type
|
||||||
|
|
||||||
Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
|
- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
|
||||||
|
- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
- 4 retries for `ContentPolicyViolationError`
|
|
||||||
- 0 retries for `RateLimitErrors`
|
```python
|
||||||
|
retry_policy = RetryPolicy(
|
||||||
|
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||||
|
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed_fails_policy = AllowedFailsPolicy(
|
||||||
|
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
|
||||||
|
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
Example Usage
|
Example Usage
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from litellm.router import RetryPolicy
|
from litellm.router import RetryPolicy, AllowedFailsPolicy
|
||||||
|
|
||||||
retry_policy = RetryPolicy(
|
retry_policy = RetryPolicy(
|
||||||
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||||
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
||||||
BadRequestErrorRetries=1,
|
BadRequestErrorRetries=1,
|
||||||
TimeoutErrorRetries=2,
|
TimeoutErrorRetries=2,
|
||||||
RateLimitErrorRetries=3,
|
RateLimitErrorRetries=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
allowed_fails_policy = AllowedFailsPolicy(
|
||||||
|
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
|
||||||
|
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
|
||||||
|
)
|
||||||
|
|
||||||
router = litellm.Router(
|
router = litellm.Router(
|
||||||
model_list=[
|
model_list=[
|
||||||
{
|
{
|
||||||
|
@ -755,6 +772,7 @@ router = litellm.Router(
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
retry_policy=retry_policy,
|
retry_policy=retry_policy,
|
||||||
|
allowed_fails_policy=allowed_fails_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await router.acompletion(
|
response = await router.acompletion(
|
||||||
|
|
|
@ -47,6 +47,7 @@ from litellm.types.router import (
|
||||||
updateDeployment,
|
updateDeployment,
|
||||||
updateLiteLLMParams,
|
updateLiteLLMParams,
|
||||||
RetryPolicy,
|
RetryPolicy,
|
||||||
|
AllowedFailsPolicy,
|
||||||
AlertingConfig,
|
AlertingConfig,
|
||||||
DeploymentTypedDict,
|
DeploymentTypedDict,
|
||||||
ModelGroupInfo,
|
ModelGroupInfo,
|
||||||
|
@ -116,6 +117,9 @@ class Router:
|
||||||
allowed_fails: Optional[
|
allowed_fails: Optional[
|
||||||
int
|
int
|
||||||
] = None, # Number of times a deployment can failbefore being added to cooldown
|
] = None, # Number of times a deployment can failbefore being added to cooldown
|
||||||
|
allowed_fails_policy: Optional[
|
||||||
|
AllowedFailsPolicy
|
||||||
|
] = None, # set custom allowed fails policy
|
||||||
cooldown_time: Optional[
|
cooldown_time: Optional[
|
||||||
float
|
float
|
||||||
] = None, # (seconds) time to cooldown a deployment after failure
|
] = None, # (seconds) time to cooldown a deployment after failure
|
||||||
|
@ -361,6 +365,7 @@ class Router:
|
||||||
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
|
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
|
||||||
model_group_retry_policy
|
model_group_retry_policy
|
||||||
)
|
)
|
||||||
|
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
|
||||||
self.alerting_config: Optional[AlertingConfig] = alerting_config
|
self.alerting_config: Optional[AlertingConfig] = alerting_config
|
||||||
if self.alerting_config is not None:
|
if self.alerting_config is not None:
|
||||||
self._initialize_alerting()
|
self._initialize_alerting()
|
||||||
|
@ -2445,6 +2450,7 @@ class Router:
|
||||||
deployment_id = _model_info.get("id", None)
|
deployment_id = _model_info.get("id", None)
|
||||||
self._set_cooldown_deployments(
|
self._set_cooldown_deployments(
|
||||||
exception_status=exception_status,
|
exception_status=exception_status,
|
||||||
|
original_exception=exception,
|
||||||
deployment=deployment_id,
|
deployment=deployment_id,
|
||||||
time_to_cooldown=_time_to_cooldown,
|
time_to_cooldown=_time_to_cooldown,
|
||||||
) # setting deployment_id in cooldown deployments
|
) # setting deployment_id in cooldown deployments
|
||||||
|
@ -2550,6 +2556,7 @@ class Router:
|
||||||
|
|
||||||
def _set_cooldown_deployments(
|
def _set_cooldown_deployments(
|
||||||
self,
|
self,
|
||||||
|
original_exception: Any,
|
||||||
exception_status: Union[str, int],
|
exception_status: Union[str, int],
|
||||||
deployment: Optional[str] = None,
|
deployment: Optional[str] = None,
|
||||||
time_to_cooldown: Optional[float] = None,
|
time_to_cooldown: Optional[float] = None,
|
||||||
|
@ -2568,6 +2575,12 @@ class Router:
|
||||||
if self._is_cooldown_required(exception_status=exception_status) == False:
|
if self._is_cooldown_required(exception_status=exception_status) == False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
_allowed_fails = self.get_allowed_fails_from_policy(
|
||||||
|
exception=original_exception,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed_fails = _allowed_fails or self.allowed_fails
|
||||||
|
|
||||||
dt = get_utc_datetime()
|
dt = get_utc_datetime()
|
||||||
current_minute = dt.strftime("%H-%M")
|
current_minute = dt.strftime("%H-%M")
|
||||||
# get current fails for deployment
|
# get current fails for deployment
|
||||||
|
@ -2577,7 +2590,7 @@ class Router:
|
||||||
current_fails = self.failed_calls.get_cache(key=deployment) or 0
|
current_fails = self.failed_calls.get_cache(key=deployment) or 0
|
||||||
updated_fails = current_fails + 1
|
updated_fails = current_fails + 1
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
|
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
|
||||||
)
|
)
|
||||||
cooldown_time = self.cooldown_time or 1
|
cooldown_time = self.cooldown_time or 1
|
||||||
if time_to_cooldown is not None:
|
if time_to_cooldown is not None:
|
||||||
|
@ -2594,7 +2607,8 @@ class Router:
|
||||||
)
|
)
|
||||||
exception_status = 500
|
exception_status = 500
|
||||||
_should_retry = litellm._should_retry(status_code=exception_status)
|
_should_retry = litellm._should_retry(status_code=exception_status)
|
||||||
if updated_fails > self.allowed_fails or _should_retry == False:
|
|
||||||
|
if updated_fails > allowed_fails or _should_retry == False:
|
||||||
# get the current cooldown list for that minute
|
# get the current cooldown list for that minute
|
||||||
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
|
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
|
||||||
cached_value = self.cache.get_cache(key=cooldown_key)
|
cached_value = self.cache.get_cache(key=cooldown_key)
|
||||||
|
@ -2737,6 +2751,7 @@ class Router:
|
||||||
except litellm.RateLimitError as e:
|
except litellm.RateLimitError as e:
|
||||||
self._set_cooldown_deployments(
|
self._set_cooldown_deployments(
|
||||||
exception_status=e.status_code,
|
exception_status=e.status_code,
|
||||||
|
original_exception=e,
|
||||||
deployment=deployment["model_info"]["id"],
|
deployment=deployment["model_info"]["id"],
|
||||||
time_to_cooldown=self.cooldown_time,
|
time_to_cooldown=self.cooldown_time,
|
||||||
)
|
)
|
||||||
|
@ -4429,6 +4444,46 @@ class Router:
|
||||||
):
|
):
|
||||||
return retry_policy.ContentPolicyViolationErrorRetries
|
return retry_policy.ContentPolicyViolationErrorRetries
|
||||||
|
|
||||||
|
def get_allowed_fails_from_policy(self, exception: Exception):
|
||||||
|
"""
|
||||||
|
BadRequestErrorRetries: Optional[int] = None
|
||||||
|
AuthenticationErrorRetries: Optional[int] = None
|
||||||
|
TimeoutErrorRetries: Optional[int] = None
|
||||||
|
RateLimitErrorRetries: Optional[int] = None
|
||||||
|
ContentPolicyViolationErrorRetries: Optional[int] = None
|
||||||
|
"""
|
||||||
|
# if we can find the exception then in the retry policy -> return the number of retries
|
||||||
|
allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
|
||||||
|
|
||||||
|
if allowed_fails_policy is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(exception, litellm.BadRequestError)
|
||||||
|
and allowed_fails_policy.BadRequestErrorAllowedFails is not None
|
||||||
|
):
|
||||||
|
return allowed_fails_policy.BadRequestErrorAllowedFails
|
||||||
|
if (
|
||||||
|
isinstance(exception, litellm.AuthenticationError)
|
||||||
|
and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
|
||||||
|
):
|
||||||
|
return allowed_fails_policy.AuthenticationErrorAllowedFails
|
||||||
|
if (
|
||||||
|
isinstance(exception, litellm.Timeout)
|
||||||
|
and allowed_fails_policy.TimeoutErrorAllowedFails is not None
|
||||||
|
):
|
||||||
|
return allowed_fails_policy.TimeoutErrorAllowedFails
|
||||||
|
if (
|
||||||
|
isinstance(exception, litellm.RateLimitError)
|
||||||
|
and allowed_fails_policy.RateLimitErrorAllowedFails is not None
|
||||||
|
):
|
||||||
|
return allowed_fails_policy.RateLimitErrorAllowedFails
|
||||||
|
if (
|
||||||
|
isinstance(exception, litellm.ContentPolicyViolationError)
|
||||||
|
and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
|
||||||
|
):
|
||||||
|
return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
|
||||||
|
|
||||||
def _initialize_alerting(self):
|
def _initialize_alerting(self):
|
||||||
from litellm.integrations.slack_alerting import SlackAlerting
|
from litellm.integrations.slack_alerting import SlackAlerting
|
||||||
|
|
||||||
|
|
|
@ -128,12 +128,17 @@ async def test_router_retries_errors(sync_mode, error_type):
|
||||||
["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], #
|
["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], #
|
||||||
)
|
)
|
||||||
async def test_router_retry_policy(error_type):
|
async def test_router_retry_policy(error_type):
|
||||||
from litellm.router import RetryPolicy
|
from litellm.router import RetryPolicy, AllowedFailsPolicy
|
||||||
|
|
||||||
retry_policy = RetryPolicy(
|
retry_policy = RetryPolicy(
|
||||||
ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0
|
ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
allowed_fails_policy = AllowedFailsPolicy(
|
||||||
|
ContentPolicyViolationErrorAllowedFails=1000,
|
||||||
|
RateLimitErrorAllowedFails=100,
|
||||||
|
)
|
||||||
|
|
||||||
router = Router(
|
router = Router(
|
||||||
model_list=[
|
model_list=[
|
||||||
{
|
{
|
||||||
|
@ -156,6 +161,7 @@ async def test_router_retry_policy(error_type):
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
retry_policy=retry_policy,
|
retry_policy=retry_policy,
|
||||||
|
allowed_fails_policy=allowed_fails_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
customHandler = MyCustomHandler()
|
customHandler = MyCustomHandler()
|
||||||
|
|
|
@ -76,7 +76,9 @@ class ModelInfo(BaseModel):
|
||||||
id: Optional[
|
id: Optional[
|
||||||
str
|
str
|
||||||
] # Allow id to be optional on input, but it will always be present as a str in the model instance
|
] # Allow id to be optional on input, but it will always be present as a str in the model instance
|
||||||
db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config.
|
db_model: bool = (
|
||||||
|
False # used for proxy - to separate models which are stored in the db vs. config.
|
||||||
|
)
|
||||||
updated_at: Optional[datetime.datetime] = None
|
updated_at: Optional[datetime.datetime] = None
|
||||||
updated_by: Optional[str] = None
|
updated_by: Optional[str] = None
|
||||||
|
|
||||||
|
@ -381,6 +383,23 @@ class RouterErrors(enum.Enum):
|
||||||
no_deployments_available = "No deployments available for selected model"
|
no_deployments_available = "No deployments available for selected model"
|
||||||
|
|
||||||
|
|
||||||
|
class AllowedFailsPolicy(BaseModel):
|
||||||
|
"""
|
||||||
|
Use this to set a custom number of allowed fails/minute before cooling down a deployment
|
||||||
|
If `AuthenticationErrorAllowedFails = 1000`, then 1000 AuthenticationError will be allowed before cooling down a deployment
|
||||||
|
|
||||||
|
Mapping of Exception type to allowed_fails for each exception
|
||||||
|
https://docs.litellm.ai/docs/exception_mapping
|
||||||
|
"""
|
||||||
|
|
||||||
|
BadRequestErrorAllowedFails: Optional[int] = None
|
||||||
|
AuthenticationErrorAllowedFails: Optional[int] = None
|
||||||
|
TimeoutErrorAllowedFails: Optional[int] = None
|
||||||
|
RateLimitErrorAllowedFails: Optional[int] = None
|
||||||
|
ContentPolicyViolationErrorAllowedFails: Optional[int] = None
|
||||||
|
InternalServerErrorAllowedFails: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class RetryPolicy(BaseModel):
|
class RetryPolicy(BaseModel):
|
||||||
"""
|
"""
|
||||||
Use this to set a custom number of retries per exception type
|
Use this to set a custom number of retries per exception type
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue