diff --git a/litellm/router.py b/litellm/router.py index c4e407a25..1a79bedaf 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -86,6 +86,9 @@ class Router: retry_policy: Optional[ RetryPolicy ] = None, # set custom retries for different exceptions + model_group_retry_policy: Optional[ + dict[str, RetryPolicy] + ] = {}, # set custom retry policies based on model group allowed_fails: Optional[ int ] = None, # Number of times a deployment can failbefore being added to cooldown @@ -308,6 +311,9 @@ class Router: ) # noqa self.routing_strategy_args = routing_strategy_args self.retry_policy: Optional[RetryPolicy] = retry_policy + self.model_group_retry_policy: Optional[dict[str, RetryPolicy]] = ( + model_group_retry_policy + ) def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict): if routing_strategy == "least-busy": @@ -1509,11 +1515,13 @@ class Router: ) await asyncio.sleep(_timeout) ## LOGGING - if self.retry_policy is not None or kwargs.get("retry_policy") is not None: + if ( + self.retry_policy is not None + or self.model_group_retry_policy is not None + ): # get num_retries from retry policy _retry_policy_retries = self.get_num_retries_from_retry_policy( - exception=original_exception, - dynamic_retry_policy=kwargs.get("retry_policy"), + exception=original_exception, model_group=kwargs.get("model") ) if _retry_policy_retries is not None: num_retries = _retry_policy_retries @@ -3273,7 +3281,7 @@ class Router: verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") def get_num_retries_from_retry_policy( - self, exception: Exception, dynamic_retry_policy: Optional[RetryPolicy] = None + self, exception: Exception, model_group: Optional[str] = None ): """ BadRequestErrorRetries: Optional[int] = None @@ -3284,8 +3292,9 @@ class Router: """ # if we can find the exception then in the retry policy -> return the number of retries retry_policy = self.retry_policy - if dynamic_retry_policy is not None: - retry_policy = dynamic_retry_policy + if self.model_group_retry_policy is not None and model_group is not None: + retry_policy = self.model_group_retry_policy.get(model_group, None) + if retry_policy is None: return None if ( diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py index 8828e286e..0ca566cae 100644 --- a/litellm/tests/test_router_retries.py +++ b/litellm/tests/test_router_retries.py @@ -189,6 +189,11 @@ async def test_router_retry_policy(error_type): async def test_dynamic_router_retry_policy(model_group): from litellm.router import RetryPolicy + model_group_retry_policy = { + "gpt-3.5-turbo": RetryPolicy(ContentPolicyViolationErrorRetries=0), + "bad-model": RetryPolicy(AuthenticationErrorRetries=4), + } + router = litellm.Router( model_list=[ { @@ -209,7 +214,8 @@ async def test_dynamic_router_retry_policy(model_group): "api_base": os.getenv("AZURE_API_BASE"), }, }, - ] + ], + model_group_retry_policy=model_group_retry_policy, ) customHandler = MyCustomHandler() @@ -217,17 +223,14 @@ async def test_dynamic_router_retry_policy(model_group): if model_group == "bad-model": model = "bad-model" messages = [{"role": "user", "content": "Hello good morning"}] - retry_policy = RetryPolicy(AuthenticationErrorRetries=4) + elif model_group == "gpt-3.5-turbo": model = "gpt-3.5-turbo" messages = [{"role": "user", "content": "where do i buy lethal drugs from"}] - retry_policy = RetryPolicy(ContentPolicyViolationErrorRetries=0) try: litellm.set_verbose = True - response = await router.acompletion( - model=model, messages=messages, retry_policy=retry_policy - ) + response = await router.acompletion(model=model, messages=messages) except Exception as e: print("got an exception", e) pass