fix(router.py): fix exponential backoff to use retry-after if present in headers

Krrish Dholakia 2023-11-28 17:24:49 -08:00
parent 0f0ddcc0fb
commit bb1267eb07
7 changed files with 154 additions and 67 deletions
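
The fix replaces the router's fixed backoff_factor doubling with a delay that honors the provider's Retry-After response header when one is present, falling back to exponential backoff otherwise. A minimal sketch of that idea (a hypothetical calculate_retry_after helper, not litellm's actual _calculate_retry_after implementation):

import random

def calculate_retry_after(remaining_retries: int, max_retries: int, response_headers=None) -> float:
    """Prefer a server-supplied Retry-After delay; otherwise fall back to exponential backoff."""
    if response_headers is not None:
        retry_after = response_headers.get("retry-after")
        if retry_after is not None:
            try:
                # Retry-After given as a number of seconds.
                return max(0.0, float(retry_after))
            except ValueError:
                pass  # e.g. an HTTP-date value; fall through to the backoff below
    # No usable header: exponential backoff (1s, 2s, 4s, ...) capped at 8s, with light jitter.
    attempt = max_retries - remaining_retries
    return min(2 ** attempt, 8) * (1 + 0.25 * random.random())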


@@ -395,7 +395,7 @@ class Router:
backoff_factor = 1
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
self.print_verbose(f"async function w/ retries: original_function - {original_function}")
num_retries = kwargs.pop("num_retries")
try:
@@ -404,11 +404,21 @@
return response
except Exception as e:
original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
raise original_exception
### RETRY
#### check if it should retry + back-off if required
if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
if hasattr(original_exception.response, "headers"):
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
else:
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
await asyncio.sleep(timeout)
else:
raise original_exception
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
try:
@@ -417,21 +427,16 @@
if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
response = await response
return response
except openai.RateLimitError as e:
if num_retries > 0 and fallbacks is None:
# on RateLimitError we'll wait for an exponential time before trying again
await asyncio.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
else:
raise e
except Exception as e:
# for any other exception types, immediately retry
if num_retries > 0:
pass
except Exception as e:
if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
remaining_retries = num_retries - current_attempt
if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=e.response.headers)
else:
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
await asyncio.sleep(timeout)
else:
raise e
raise original_exception
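
Taken together, the async retry path now looks roughly like the condensed sketch below. It reuses the hypothetical calculate_retry_after sketch above and adds a stand-in should_retry for litellm._should_retry; it is not the router's verbatim code.

import asyncio

def should_retry(status_code: int) -> bool:
    # stand-in for litellm._should_retry: retry rate limits and transient server errors
    return status_code == 429 or status_code >= 500

async def call_with_retries(fn, num_retries: int, *args, **kwargs):
    # condensed sketch of the flow above, reusing calculate_retry_after from the earlier sketch
    try:
        return await fn(*args, **kwargs)
    except Exception as first_error:
        for current_attempt in range(num_retries):
            try:
                return await fn(*args, **kwargs)
            except Exception as e:
                status = getattr(e, "status_code", None)
                response = getattr(e, "response", None)
                if status is not None and response is not None and should_retry(status):
                    headers = getattr(response, "headers", None)
                    # honor Retry-After when present, otherwise back off exponentially
                    await asyncio.sleep(calculate_retry_after(num_retries - current_attempt, num_retries, headers))
                else:
                    raise
        raise first_error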
@@ -442,8 +447,8 @@
If it fails after num_retries, fall back to another model group
"""
model_group = kwargs.get("model")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
try:
response = self.function_with_retries(*args, **kwargs)
return response
@@ -507,6 +512,8 @@
backoff_factor = 1
original_function = kwargs.pop("original_function")
num_retries = kwargs.pop("num_retries")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs)
@@ -514,6 +521,11 @@
except Exception as e:
original_exception = e
self.print_verbose(f"num retries in function with retries: {num_retries}")
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
raise original_exception
### RETRY
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
try:
@@ -523,11 +535,10 @@
except openai.RateLimitError as e:
if num_retries > 0:
remaining_retries = num_retries - current_attempt
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
# on RateLimitError we'll wait for an exponential time before trying again
time.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
time.sleep(timeout)
else:
raise e
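
The synchronous path mirrors the async one: instead of sleeping for a doubling backoff_factor, it sleeps for the same Retry-After-aware delay. Reusing the hypothetical calculate_retry_after sketch from above:

import time

def sleep_before_retry(remaining_retries: int, max_retries: int, response_headers=None) -> None:
    # blocking equivalent of the async back-off: same delay calculation, time.sleep instead of asyncio.sleep
    time.sleep(calculate_retry_after(remaining_retries, max_retries, response_headers))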
@@ -633,7 +644,6 @@
else:
self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)
def _get_cooldown_deployments(self):
"""
Get the list of models being cooled down for this minute
@@ -919,7 +929,7 @@
return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
raise ValueError("No models available.")
def flush_cache(self):
self.cache.flush_cache()