From 2686894823cdfd52c82e7c8c74e64f9d7a4ebc56 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 24 Nov 2023 13:27:44 -0800
Subject: [PATCH] fix(router.py): fix retry logic

---
 litellm/__init__.py                     |  1 +
 litellm/router.py                       | 54 ++++++++++++++-----------
 litellm/tests/test_acooldowns_router.py |  2 +-
 3 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index e7ad788a01..d08cdf153c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -57,6 +57,7 @@ request_timeout: Optional[float] = 6000
 num_retries: Optional[int] = None
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
+allowed_fails: int = 0
 #############################################

 def get_model_cost_map(url: str):
diff --git a/litellm/router.py b/litellm/router.py
index bd77f50f75..82199ee725 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -68,6 +68,7 @@ class Router:
                  default_litellm_params = {}, # default params for Router.chat.completion.create
                  set_verbose: bool = False,
                  fallbacks: List = [],
+                 allowed_fails: Optional[int] = None,
                  context_window_fallbacks: List = [],
                  routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:

@@ -75,10 +76,11 @@ class Router:
             self.set_model_list(model_list)
             self.healthy_deployments: List = self.model_list
             self.deployment_latency_map = {}
-            self.cooldown_deployments: dict = {} # {"gpt-3.5-turbo": time.time() when it failed / needed a cooldown}
             for m in model_list:
                 self.deployment_latency_map[m["litellm_params"]["model"]] = 0

+        self.allowed_fails = allowed_fails or litellm.allowed_fails
+        self.failed_calls = InMemoryCache() # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
         self.num_retries = num_retries or litellm.num_retries
         self.set_verbose = set_verbose
         self.timeout = timeout or litellm.request_timeout
@@ -331,7 +333,6 @@ class Router:
             return response
         except Exception as e:
             for current_attempt in range(num_retries):
-                num_retries -= 1 # decrement the number of retries
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
                     # if the function call is successful, no exception will be raised and we'll break out of the loop
@@ -356,8 +357,7 @@
                         pass
                     else:
                         raise e
-            if self.num_retries == 0:
-                raise e
+            raise e

     def function_with_fallbacks(self, *args, **kwargs):
         """
@@ -435,8 +435,7 @@
         except Exception as e:
             self.print_verbose(f"num retries in function with retries: {num_retries}")
             for current_attempt in range(num_retries):
-                num_retries -= 1 # decrement the number of retries
-                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
+                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
                 try:
                     # if the function call is successful, no exception will be raised and we'll break out of the loop
                     response = original_function(*args, **kwargs)
@@ -458,8 +457,7 @@
                         pass
                     else:
                         raise e
-            if self.num_retries == 0:
-                raise e
+            raise e

     ### HELPER FUNCTIONS
@@ -506,27 +504,37 @@ class Router:

     def _set_cooldown_deployments(self, deployment: str):
         """
-        Add a model to the list of models being cooled down for that minute
+        Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
         """
         current_minute = datetime.now().strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
-        cached_value = self.cache.get_cache(key=cooldown_key)
+        # get current fails for deployment
+        # update the number of failed calls
+        # if it's > allowed fails
+        # cooldown deployment
+        current_fails = self.failed_calls.get_cache(key=deployment) or 0
+        updated_fails = current_fails + 1
+        if updated_fails > self.allowed_fails:
+            # get the current cooldown list for that minute
+            cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
+            cached_value = self.cache.get_cache(key=cooldown_key)

-        self.print_verbose(f"adding {deployment} to cooldown models")
-        # update value
-        try:
-            if deployment in cached_value:
-                pass
-            else:
-                cached_value = cached_value + [deployment]
+            self.print_verbose(f"adding {deployment} to cooldown models")
+            # update value
+            try:
+                if deployment in cached_value:
+                    pass
+                else:
+                    cached_value = cached_value + [deployment]
+                    # save updated value
+                    self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
+            except:
+                cached_value = [deployment]
                 # save updated value
                 self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
-        except:
-            cached_value = [deployment]
-            # save updated value
-            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
+        else:
+            self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)
+
     def _get_cooldown_deployments(self):
         """
diff --git a/litellm/tests/test_acooldowns_router.py b/litellm/tests/test_acooldowns_router.py
index 0c50c686f8..9ea7c3f6cf 100644
--- a/litellm/tests/test_acooldowns_router.py
+++ b/litellm/tests/test_acooldowns_router.py
@@ -22,7 +22,7 @@ model_list = [{ # list of model deployments
         "api_base": os.getenv("AZURE_API_BASE")
     },
     "tpm": 240000,
-    "rpm": 1800
+    "rpm": 1800,
 }, {
     "model_name": "gpt-3.5-turbo", # openai model name
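
For context, a minimal usage sketch of the `allowed_fails` knob this patch introduces. The `Router(...)` parameters come straight from the diff; the deployment config mirrors the test file above, but `azure/chatgpt-v-2` and the env var names are placeholders, not part of the patch:

```python
import os
import litellm
from litellm import Router

# module-level default added in litellm/__init__.py; 0 means a single
# failure is enough to put a deployment into cooldown
litellm.allowed_fails = 0

model_list = [{
    "model_name": "gpt-3.5-turbo",      # alias callers route against
    "litellm_params": {
        "model": "azure/chatgpt-v-2",   # placeholder deployment name
        "api_key": os.getenv("AZURE_API_KEY"),
        "api_version": os.getenv("AZURE_API_VERSION"),
        "api_base": os.getenv("AZURE_API_BASE"),
    },
    "tpm": 240000,
    "rpm": 1800,
}]

# per-router override: a deployment must fail more than 3 times within
# a minute before _set_cooldown_deployments cools it down for 60s
router = Router(model_list=model_list, allowed_fails=3, num_retries=2)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
```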
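The cooldown bookkeeping is easy to misread in diff form, so here is a simplified standalone sketch of the counting scheme `_set_cooldown_deployments` now implements. `SimpleTTLCache` is an illustrative stand-in, not litellm's `InMemoryCache` API:

```python
import time
from datetime import datetime

class SimpleTTLCache:
    """Illustrative stand-in for an in-memory cache whose entries expire after ttl seconds."""
    def __init__(self):
        self._store = {}  # key -> (value, expiry_timestamp)

    def get_cache(self, key):
        value, expiry = self._store.get(key, (None, 0))
        return value if time.time() < expiry else None

    def set_cache(self, key, value, ttl):
        self._store[key] = (value, time.time() + ttl)

def set_cooldown(cache, failed_calls, deployment, allowed_fails):
    """Mirror of _set_cooldown_deployments: only cool a deployment down
    once it exceeds allowed_fails failures within a 60s window."""
    current_fails = failed_calls.get_cache(key=deployment) or 0
    updated_fails = current_fails + 1
    if updated_fails > allowed_fails:
        # group cooldown models by clock minute, as the patch does,
        # so one key holds every deployment cooled down that minute
        current_minute = datetime.now().strftime("%H-%M")
        cooldown_key = f"{current_minute}:cooldown_models"
        cooled = cache.get_cache(key=cooldown_key) or []
        if deployment not in cooled:
            cooled = cooled + [deployment]
        cache.set_cache(key=cooldown_key, value=cooled, ttl=60)
    else:
        # record the failure; the count expires after a minute
        failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)
```

One behavioral note: each sub-threshold failure rewrites the counter with a fresh 60-second TTL, so the failure window slides with the last failure, while the cooldown list itself is keyed by clock minute.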