fix(router.py): fix retry logic

Krrish Dholakia 2023-11-24 13:27:44 -08:00
parent 16e1070dbe
commit 2686894823
3 changed files with 33 additions and 24 deletions

@@ -57,6 +57,7 @@ request_timeout: Optional[float] = 6000
 num_retries: Optional[int] = None
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
+allowed_fails: int = 0
 #############################################
 def get_model_cost_map(url: str):
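With allowed_fails: int = 0 as the module default, a single failure within a minute is enough to cool a deployment down. A minimal sketch of overriding the default, assuming this setting lives at the top level of the litellm package as shown above:

    import litellm

    # tolerate up to 2 failures per deployment per minute before cooldown
    litellm.allowed_fails = 2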

@@ -68,6 +68,7 @@ class Router:
                  default_litellm_params = {}, # default params for Router.chat.completion.create
                  set_verbose: bool = False,
                  fallbacks: List = [],
+                 allowed_fails: Optional[int] = None,
                  context_window_fallbacks: List = [],
                  routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
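A sketch of how a caller might use the new constructor parameter; `model_list` here stands for a deployment list like the one in the test config at the end of this commit:

    from litellm import Router

    router = Router(
        model_list=model_list,
        num_retries=2,
        allowed_fails=3,  # overrides litellm.allowed_fails for this router
    )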
@@ -75,10 +76,11 @@
         self.set_model_list(model_list)
         self.healthy_deployments: List = self.model_list
         self.deployment_latency_map = {}
-        self.cooldown_deployments: dict = {} # {"gpt-3.5-turbo": time.time() when it failed / needed a cooldown}
         for m in model_list:
             self.deployment_latency_map[m["litellm_params"]["model"]] = 0
+        self.allowed_fails = allowed_fails or litellm.allowed_fails
+        self.failed_calls = InMemoryCache() # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
         self.num_retries = num_retries or litellm.num_retries
         self.set_verbose = set_verbose
         self.timeout = timeout or litellm.request_timeout
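One caveat with the `or` fallback used for these settings: any falsy argument is treated as unset, so an explicit allowed_fails=0 (or num_retries=0) silently falls back to the module-level default:

    # current behavior: 0 is falsy, so the module default wins
    self.allowed_fails = allowed_fails or litellm.allowed_fails
    # a variant that would preserve an explicit 0 (not what this commit does):
    # self.allowed_fails = allowed_fails if allowed_fails is not None else litellm.allowed_fails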
@@ -331,7 +333,6 @@
             return response
         except Exception as e:
             for current_attempt in range(num_retries):
-                num_retries -= 1 # decrement the number of retries
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
                     # if the function call is successful, no exception will be raised and we'll break out of the loop
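The removed decrement never affected loop control, because range() fixes its bound when the loop starts; it only made the logged retry count drift. A quick illustration:

    # decrementing inside the loop does not shorten it
    num_retries = 3
    attempts = []
    for current_attempt in range(num_retries):
        num_retries -= 1  # no effect on the number of iterations
        attempts.append(current_attempt)
    assert attempts == [0, 1, 2] and num_retries == 0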
@@ -356,8 +357,7 @@
                         pass
                     else:
                         raise e
-            if self.num_retries == 0:
-                raise e
+            raise e
 
     def function_with_fallbacks(self, *args, **kwargs):
         """
@@ -435,8 +435,7 @@
         except Exception as e:
             self.print_verbose(f"num retries in function with retries: {num_retries}")
             for current_attempt in range(num_retries):
-                num_retries -= 1 # decrement the number of retries
-                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
+                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
                 try:
                     # if the function call is successful, no exception will be raised and we'll break out of the loop
                     response = original_function(*args, **kwargs)
@@ -458,8 +457,7 @@
                         pass
                     else:
                         raise e
-            if self.num_retries == 0:
-                raise e
+            raise e
 
     ### HELPER FUNCTIONS
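Both retry loops now end with an unconditional raise. Previously the final `raise e` was guarded by `if self.num_retries == 0`, so whenever num_retries was greater than zero an exhausted loop fell through and the function returned None instead of surfacing the last error. A self-contained sketch of the fixed shape, where call_with_retries is a hypothetical stand-in for the router methods above:

    def call_with_retries(fn, num_retries, *args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as first_err:
            last_err = first_err
            for current_attempt in range(num_retries):
                try:
                    return fn(*args, **kwargs)  # success returns immediately
                except Exception as retry_err:
                    # the real code re-raises non-retryable errors here
                    last_err = retry_err
            raise last_err  # retries exhausted: always surface the last error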
@@ -506,27 +504,37 @@
     def _set_cooldown_deployments(self,
                                   deployment: str):
         """
-        Add a model to the list of models being cooled down for that minute
+        Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
         """
         current_minute = datetime.now().strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
-        cached_value = self.cache.get_cache(key=cooldown_key)
-        self.print_verbose(f"adding {deployment} to cooldown models")
-        # update value
-        try:
-            if deployment in cached_value:
-                pass
-            else:
-                cached_value = cached_value + [deployment]
-                # save updated value
-                self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
-        except:
-            cached_value = [deployment]
-            # save updated value
-            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
+        # get current fails for deployment
+        # update the number of failed calls
+        # if it's > allowed fails
+        # cooldown deployment
+        current_fails = self.failed_calls.get_cache(key=deployment) or 0
+        updated_fails = current_fails + 1
+        if updated_fails > self.allowed_fails:
+            # get the current cooldown list for that minute
+            cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
+            cached_value = self.cache.get_cache(key=cooldown_key)
+            self.print_verbose(f"adding {deployment} to cooldown models")
+            # update value
+            try:
+                if deployment in cached_value:
+                    pass
+                else:
+                    cached_value = cached_value + [deployment]
+                    # save updated value
+                    self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
+            except:
+                cached_value = [deployment]
+                # save updated value
+                self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
+        else:
+            self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)
 
     def _get_cooldown_deployments(self):
         """

@@ -22,7 +22,7 @@ model_list = [{ # list of model deployments
             "api_base": os.getenv("AZURE_API_BASE")
         },
         "tpm": 240000,
-        "rpm": 1800
+        "rpm": 1800,
     },
     {
         "model_name": "gpt-3.5-turbo", # openai model name