fix(router.py): fix exponential backoff to use retry-after if present in headers

Krrish Dholakia 2023-11-28 17:24:49 -08:00
parent 0f0ddcc0fb
commit bb1267eb07
7 changed files with 154 additions and 67 deletions
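In short: instead of sleeping for backoff_factor and doubling it on every RateLimitError, the router now consults the two helpers added to litellm.utils in this commit. _should_retry() decides from the status code (408, 409, 429, 5xx) whether a retry makes sense, and _calculate_retry_after() honors a Retry-After response header when present, falling back to capped exponential backoff with jitter otherwise. A minimal sketch of the resulting retry loop follows; retry_with_backoff and call_llm are hypothetical names for illustration, not part of the commit.

import asyncio

import litellm

async def retry_with_backoff(call_llm, num_retries: int):
    # call_llm: hypothetical zero-argument coroutine wrapping the actual completion call.
    for current_attempt in range(num_retries):
        try:
            return await call_llm()
        except Exception as e:
            # Retry only errors the helper considers retryable (408, 409, 429, 5xx).
            if hasattr(e, "status_code") and litellm._should_retry(status_code=e.status_code):
                headers = getattr(getattr(e, "response", None), "headers", None)
                # Prefer the server's Retry-After value; otherwise capped exponential backoff with jitter.
                timeout = litellm._calculate_retry_after(
                    remaining_retries=num_retries - current_attempt,
                    max_retries=num_retries,
                    response_headers=headers,
                )
                await asyncio.sleep(timeout)
            else:
                raise
    return await call_llm()  # final attempt once retries are exhausted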

View file

@@ -364,7 +364,9 @@ from .utils import (
completion_with_config,
register_model,
encode,
decode
decode,
_calculate_retry_after,
_should_retry
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig

View file

@@ -395,7 +395,7 @@ class Router:
backoff_factor = 1
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
self.print_verbose(f"async function w/ retries: original_function - {original_function}")
num_retries = kwargs.pop("num_retries")
try:
@@ -404,11 +404,21 @@
return response
except Exception as e:
original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
raise original_exception
### RETRY
#### check if it should retry + back-off if required
if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
if hasattr(original_exception.response, "headers"):
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
else:
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
await asyncio.sleep(timeout)
else:
raise original_exception
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
try:
@@ -417,21 +427,16 @@
if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
response = await response
return response
except openai.RateLimitError as e:
if num_retries > 0 and fallbacks is None:
# on RateLimitError we'll wait for an exponential time before trying again
await asyncio.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
else:
raise e
except Exception as e:
# for any other exception types, immediately retry
if num_retries > 0:
pass
except Exception as e:
if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
remaining_retries = num_retries - current_attempt
if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=e.response.headers)
else:
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
await asyncio.sleep(timeout)
else:
raise e
raise original_exception
@@ -442,8 +447,8 @@
If it fails after num_retries, fall back to another model group
"""
model_group = kwargs.get("model")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
try:
response = self.function_with_retries(*args, **kwargs)
return response
@@ -507,6 +512,8 @@
backoff_factor = 1
original_function = kwargs.pop("original_function")
num_retries = kwargs.pop("num_retries")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs)
@@ -514,6 +521,11 @@
except Exception as e:
original_exception = e
self.print_verbose(f"num retries in function with retries: {num_retries}")
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
raise original_exception
### RETRY
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
try:
@@ -523,11 +535,10 @@
except openai.RateLimitError as e:
if num_retries > 0:
remaining_retries = num_retries - current_attempt
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
# on RateLimitError we'll wait for an exponential time before trying again
time.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
time.sleep(timeout)
else:
raise e
@@ -633,7 +644,6 @@ class Router:
else:
self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)
def _get_cooldown_deployments(self):
"""
Get the list of models being cooled down for this minute
@@ -919,7 +929,7 @@ class Router:
return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
raise ValueError("No models available.")
def flush_cache(self):
self.cache.flush_cache()

View file

@@ -74,6 +74,8 @@ def test_async_response_azure():
asyncio.run(test_get_response())
# test_async_response_azure()
def test_async_anyscale_response():
import asyncio
litellm.set_verbose = True
@@ -162,4 +164,4 @@ def test_get_response_non_openai_streaming():
return response
asyncio.run(test_async_call())
test_get_response_non_openai_streaming()
# test_get_response_non_openai_streaming()

View file

@@ -37,7 +37,13 @@ def test_chat_openai():
}],
stream=True,
complete_response = True)
response2 = completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm not openai"
}],
stream=True,
complete_response = True)
time.sleep(1)
assert customHandler.success == True
except Exception as e:

View file

@@ -227,10 +227,10 @@ async def asynctest_completion_azure_exception():
print("exception", e)
pytest.fail(f"Error occurred: {e}")
import asyncio
asyncio.run(
asynctest_completion_azure_exception()
)
# import asyncio
# asyncio.run(
# asynctest_completion_azure_exception()
# )
def test_completion_openai_exception():
@@ -265,39 +265,40 @@ def test_completion_openai_exception():
# test_invalid_request_error(model="command-nightly")
# Test 3: Rate Limit Errors
# def test_model_call(model):
# try:
# sample_text = "how does a court case get to the Supreme Court?"
# messages = [{ "content": sample_text,"role": "user"}]
# print(f"model: {model}")
# response = completion(model=model, messages=messages)
# except RateLimitError:
# return True
# # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
# # return True
# except Exception as e:
# print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
# traceback.print_exc()
# pass
# return False
# # Repeat each model 500 times
# # extended_models = [model for model in models for _ in range(250)]
# extended_models = ["gpt-3.5-turbo-instruct" for _ in range(250)]
def test_model_call(model):
try:
sample_text = "how does a court case get to the Supreme Court?"
messages = [{ "content": sample_text,"role": "user"}]
print(f"model: {model}")
response = completion(model=model, messages=messages)
except RateLimitError as e:
print(f"headers: {e.response.headers}")
return True
# except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
# return True
except Exception as e:
print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
traceback.print_exc()
pass
return False
# Repeat each model 500 times
# extended_models = [model for model in models for _ in range(250)]
extended_models = ["azure/chatgpt-v-2" for _ in range(250)]
# def worker(model):
# return test_model_call(model)
def worker(model):
return test_model_call(model)
# # Create a dictionary to store the results
# counts = {True: 0, False: 0}
# Create a dictionary to store the results
counts = {True: 0, False: 0}
# # Use Thread Pool Executor
# with ThreadPoolExecutor(max_workers=500) as executor:
# # Use map to start the operation in thread pool
# results = executor.map(worker, extended_models)
# Use Thread Pool Executor
with ThreadPoolExecutor(max_workers=500) as executor:
# Use map to start the operation in thread pool
results = executor.map(worker, extended_models)
# # Iterate over results and count True/False
# for result in results:
# counts[result] += 1
# Iterate over results and count True/False
for result in results:
counts[result] += 1
# accuracy_score = counts[True]/(counts[True] + counts[False])
# print(f"accuracy_score: {accuracy_score}")
accuracy_score = counts[True]/(counts[True] + counts[False])
print(f"accuracy_score: {accuracy_score}")

View file

@@ -55,7 +55,6 @@
# try:
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
# response = await router.acompletion(model="azure-model", messages=messages)
# # response = await litellm.acompletion(model="azure/gpt-35-turbo", messages=messages, api_key="6a0f46e99d554e8caad9c2b7c0ba7319", api_base="https://my-endpoint-canada-berri992.openai.azure.com")
# return response
# except Exception as e:
# print(e, file=sys.stderr)
@@ -64,7 +63,7 @@
# async def loadtest_fn():
# start = time.time()
# n = 100
# n = 1000
# tasks = [router_completion() for _ in range(n)]
# chat_completions = await asyncio.gather(*tasks)
# successful_completions = [c for c in chat_completions if c is not None]

View file

@@ -2735,7 +2735,6 @@ def json_schema_type(python_type_name: str):
return python_to_json_schema_types.get(python_type_name, "string")
def function_to_dict(input_function): # noqa: C901
"""Using type hints and numpy-styled docstring,
produce a dictionary usable for OpenAI function calling
@@ -3136,7 +3135,7 @@ def set_callbacks(callback_list, function_id=None):
except Exception as e:
raise e
# NOTE: DEPRECATING this in favor of using failure_handler() in Logging:
def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
try:
@@ -3473,6 +3472,74 @@ def check_valid_key(model: str, api_key: str):
except Exception as e:
return False
def _should_retry(status_code: int):
"""
Reimplementation of OpenAI's should-retry logic, since it can't be imported directly.
https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L639
"""
# If the server explicitly says whether or not to retry, obey.
# Retry on request timeouts.
if status_code == 408:
return True
# Retry on lock timeouts.
if status_code == 409:
return True
# Retry on rate limits.
if status_code == 429:
return True
# Retry internal errors.
if status_code >= 500:
return True
return False
def _calculate_retry_after(remaining_retries: int, max_retries: int, response_headers: Optional[httpx.Headers]=None):
"""
Reimplementation of OpenAI's retry-after calculation, since it can't be imported directly.
https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L631
"""
try:
import email.utils # stdlib; parses an HTTP-date style Retry-After value below
# About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
#
# The value may be an integer number of seconds or an HTTP-date ("<http-date>").
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for details.
if response_headers is not None:
retry_header = response_headers.get("retry-after")
try:
retry_after = int(retry_header)
except Exception:
retry_date_tuple = email.utils.parsedate_tz(retry_header)
if retry_date_tuple is None:
retry_after = -1
else:
retry_date = email.utils.mktime_tz(retry_date_tuple)
retry_after = int(retry_date - time.time())
else:
retry_after = -1
except Exception:
retry_after = -1
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
if 0 < retry_after <= 60:
return retry_after
initial_retry_delay = 0.5
max_retry_delay = 8.0
nb_retries = max_retries - remaining_retries
# Apply exponential backoff, but not more than the max.
sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
# Apply some jitter: scale the delay by a random factor between 0.75 and 1.0.
jitter = 1 - 0.25 * random.random()
timeout = sleep_seconds * jitter
return timeout if timeout >= 0 else 0
# integration helper function
def modify_integration(integration_name, integration_params):
global supabaseClient
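
For a rough sense of how the two helpers above behave (the numbers are made up for illustration; httpx.Headers is the type the function already accepts):

import httpx

import litellm

# 429 (rate limit) is retryable; 400 (bad request) is not.
assert litellm._should_retry(status_code=429) is True
assert litellm._should_retry(status_code=400) is False

# A reasonable Retry-After header (0 < value <= 60 seconds) is used as-is.
headers = httpx.Headers({"retry-after": "7"})
print(litellm._calculate_retry_after(remaining_retries=2, max_retries=3, response_headers=headers))  # 7

# Without the header: 0.5s * 2**(retries used so far), capped at 8s, scaled by jitter (0.75-1.0).
print(litellm._calculate_retry_after(remaining_retries=2, max_retries=3))  # roughly 0.75 to 1.0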