mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
fix(router.py): fix exponential backoff to use retry-after if present in headers
This commit is contained in:
parent
0f0ddcc0fb
commit
bb1267eb07
7 changed files with 154 additions and 67 deletions
|
@ -364,7 +364,9 @@ from .utils import (
|
|||
completion_with_config,
|
||||
register_model,
|
||||
encode,
|
||||
decode
|
||||
decode,
|
||||
_calculate_retry_after,
|
||||
_should_retry
|
||||
)
|
||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||
from .llms.anthropic import AnthropicConfig
|
||||
|
|
|
@ -395,7 +395,7 @@ class Router:
|
|||
backoff_factor = 1
|
||||
original_function = kwargs.pop("original_function")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
|
||||
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
|
||||
self.print_verbose(f"async function w/ retries: original_function - {original_function}")
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
try:
|
||||
|
@ -404,11 +404,21 @@ class Router:
|
|||
return response
|
||||
except Exception as e:
|
||||
original_exception = e
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
|
||||
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
|
||||
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
|
||||
raise original_exception
|
||||
### RETRY
|
||||
#### check if it should retry + back-off if required
|
||||
if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
|
||||
if hasattr(original_exception.response, "headers"):
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
|
||||
else:
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
|
||||
await asyncio.sleep(timeout)
|
||||
else:
|
||||
raise original_exception
|
||||
|
||||
for current_attempt in range(num_retries):
|
||||
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
|
||||
try:
|
||||
|
@ -417,21 +427,16 @@ class Router:
|
|||
if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
|
||||
response = await response
|
||||
return response
|
||||
|
||||
except openai.RateLimitError as e:
|
||||
if num_retries > 0 and fallbacks is None:
|
||||
# on RateLimitError we'll wait for an exponential time before trying again
|
||||
await asyncio.sleep(backoff_factor)
|
||||
|
||||
# increase backoff factor for next run
|
||||
backoff_factor *= 2
|
||||
else:
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
# for any other exception types, immediately retry
|
||||
if num_retries > 0:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
|
||||
remaining_retries = num_retries - current_attempt
|
||||
if hasattr(e.response, "headers"):
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=e.response.headers)
|
||||
else:
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
|
||||
await asyncio.sleep(timeout)
|
||||
else:
|
||||
raise e
|
||||
raise original_exception
|
||||
|
@ -442,8 +447,8 @@ class Router:
|
|||
If it fails after num_retries, fall back to another model group
|
||||
"""
|
||||
model_group = kwargs.get("model")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
|
||||
fallbacks = kwargs.get("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
|
||||
try:
|
||||
response = self.function_with_retries(*args, **kwargs)
|
||||
return response
|
||||
|
@ -507,6 +512,8 @@ class Router:
|
|||
backoff_factor = 1
|
||||
original_function = kwargs.pop("original_function")
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
response = original_function(*args, **kwargs)
|
||||
|
@ -514,6 +521,11 @@ class Router:
|
|||
except Exception as e:
|
||||
original_exception = e
|
||||
self.print_verbose(f"num retries in function with retries: {num_retries}")
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
|
||||
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
|
||||
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
|
||||
raise original_exception
|
||||
### RETRY
|
||||
for current_attempt in range(num_retries):
|
||||
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
|
||||
try:
|
||||
|
@ -523,11 +535,10 @@ class Router:
|
|||
|
||||
except openai.RateLimitError as e:
|
||||
if num_retries > 0:
|
||||
remaining_retries = num_retries - current_attempt
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
|
||||
# on RateLimitError we'll wait for an exponential time before trying again
|
||||
time.sleep(backoff_factor)
|
||||
|
||||
# increase backoff factor for next run
|
||||
backoff_factor *= 2
|
||||
time.sleep(timeout)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
@ -633,7 +644,6 @@ class Router:
|
|||
else:
|
||||
self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)
|
||||
|
||||
|
||||
def _get_cooldown_deployments(self):
|
||||
"""
|
||||
Get the list of models being cooled down for this minute
|
||||
|
@ -919,7 +929,7 @@ class Router:
|
|||
return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
|
||||
|
||||
raise ValueError("No models available.")
|
||||
|
||||
|
||||
def flush_cache(self):
|
||||
self.cache.flush_cache()
|
||||
|
||||
|
|
|
@ -74,6 +74,8 @@ def test_async_response_azure():
|
|||
|
||||
asyncio.run(test_get_response())
|
||||
|
||||
# test_async_response_azure()
|
||||
|
||||
def test_async_anyscale_response():
|
||||
import asyncio
|
||||
litellm.set_verbose = True
|
||||
|
@ -162,4 +164,4 @@ def test_get_response_non_openai_streaming():
|
|||
return response
|
||||
asyncio.run(test_async_call())
|
||||
|
||||
test_get_response_non_openai_streaming()
|
||||
# test_get_response_non_openai_streaming()
|
|
@ -37,7 +37,13 @@ def test_chat_openai():
|
|||
}],
|
||||
stream=True,
|
||||
complete_response = True)
|
||||
|
||||
response2 = completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm not openai"
|
||||
}],
|
||||
stream=True,
|
||||
complete_response = True)
|
||||
time.sleep(1)
|
||||
assert customHandler.success == True
|
||||
except Exception as e:
|
||||
|
|
|
@ -227,10 +227,10 @@ async def asynctest_completion_azure_exception():
|
|||
print("exception", e)
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
import asyncio
|
||||
asyncio.run(
|
||||
asynctest_completion_azure_exception()
|
||||
)
|
||||
# import asyncio
|
||||
# asyncio.run(
|
||||
# asynctest_completion_azure_exception()
|
||||
# )
|
||||
|
||||
|
||||
def test_completion_openai_exception():
|
||||
|
@ -265,39 +265,40 @@ def test_completion_openai_exception():
|
|||
|
||||
# test_invalid_request_error(model="command-nightly")
|
||||
# Test 3: Rate Limit Errors
|
||||
# def test_model_call(model):
|
||||
# try:
|
||||
# sample_text = "how does a court case get to the Supreme Court?"
|
||||
# messages = [{ "content": sample_text,"role": "user"}]
|
||||
# print(f"model: {model}")
|
||||
# response = completion(model=model, messages=messages)
|
||||
# except RateLimitError:
|
||||
# return True
|
||||
# # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
|
||||
# # return True
|
||||
# except Exception as e:
|
||||
# print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
|
||||
# traceback.print_exc()
|
||||
# pass
|
||||
# return False
|
||||
# # Repeat each model 500 times
|
||||
# # extended_models = [model for model in models for _ in range(250)]
|
||||
# extended_models = ["gpt-3.5-turbo-instruct" for _ in range(250)]
|
||||
def test_model_call(model):
|
||||
try:
|
||||
sample_text = "how does a court case get to the Supreme Court?"
|
||||
messages = [{ "content": sample_text,"role": "user"}]
|
||||
print(f"model: {model}")
|
||||
response = completion(model=model, messages=messages)
|
||||
except RateLimitError as e:
|
||||
print(f"headers: {e.response.headers}")
|
||||
return True
|
||||
# except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
|
||||
# return True
|
||||
except Exception as e:
|
||||
print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
return False
|
||||
# Repeat each model 500 times
|
||||
# extended_models = [model for model in models for _ in range(250)]
|
||||
extended_models = ["azure/chatgpt-v-2" for _ in range(250)]
|
||||
|
||||
# def worker(model):
|
||||
# return test_model_call(model)
|
||||
def worker(model):
|
||||
return test_model_call(model)
|
||||
|
||||
# # Create a dictionary to store the results
|
||||
# counts = {True: 0, False: 0}
|
||||
# Create a dictionary to store the results
|
||||
counts = {True: 0, False: 0}
|
||||
|
||||
# # Use Thread Pool Executor
|
||||
# with ThreadPoolExecutor(max_workers=500) as executor:
|
||||
# # Use map to start the operation in thread pool
|
||||
# results = executor.map(worker, extended_models)
|
||||
# Use Thread Pool Executor
|
||||
with ThreadPoolExecutor(max_workers=500) as executor:
|
||||
# Use map to start the operation in thread pool
|
||||
results = executor.map(worker, extended_models)
|
||||
|
||||
# # Iterate over results and count True/False
|
||||
# for result in results:
|
||||
# counts[result] += 1
|
||||
# Iterate over results and count True/False
|
||||
for result in results:
|
||||
counts[result] += 1
|
||||
|
||||
# accuracy_score = counts[True]/(counts[True] + counts[False])
|
||||
# print(f"accuracy_score: {accuracy_score}")
|
||||
accuracy_score = counts[True]/(counts[True] + counts[False])
|
||||
print(f"accuracy_score: {accuracy_score}")
|
||||
|
|
|
@ -55,7 +55,6 @@
|
|||
# try:
|
||||
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
|
||||
# response = await router.acompletion(model="azure-model", messages=messages)
|
||||
# # response = await litellm.acompletion(model="azure/gpt-35-turbo", messages=messages, api_key="6a0f46e99d554e8caad9c2b7c0ba7319", api_base="https://my-endpoint-canada-berri992.openai.azure.com")
|
||||
# return response
|
||||
# except Exception as e:
|
||||
# print(e, file=sys.stderr)
|
||||
|
@ -64,7 +63,7 @@
|
|||
|
||||
# async def loadtest_fn():
|
||||
# start = time.time()
|
||||
# n = 100
|
||||
# n = 1000
|
||||
# tasks = [router_completion() for _ in range(n)]
|
||||
# chat_completions = await asyncio.gather(*tasks)
|
||||
# successful_completions = [c for c in chat_completions if c is not None]
|
||||
|
|
|
@ -2735,7 +2735,6 @@ def json_schema_type(python_type_name: str):
|
|||
|
||||
return python_to_json_schema_types.get(python_type_name, "string")
|
||||
|
||||
|
||||
def function_to_dict(input_function): # noqa: C901
|
||||
"""Using type hints and numpy-styled docstring,
|
||||
produce a dictionnary usable for OpenAI function calling
|
||||
|
@ -3136,7 +3135,7 @@ def set_callbacks(callback_list, function_id=None):
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
# NOTE: DEPRECATING this in favor of using failure_handler() in Logging:
|
||||
def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
|
||||
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
|
||||
try:
|
||||
|
@ -3473,6 +3472,74 @@ def check_valid_key(model: str, api_key: str):
|
|||
except Exception as e:
|
||||
return False
|
||||
|
||||
def _should_retry(status_code: int):
|
||||
"""
|
||||
Reimplementation of openai's should retry logic, since that one can't be imported.
|
||||
https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L639
|
||||
"""
|
||||
# If the server explicitly says whether or not to retry, obey.
|
||||
# Retry on request timeouts.
|
||||
if status_code == 408:
|
||||
return True
|
||||
|
||||
# Retry on lock timeouts.
|
||||
if status_code == 409:
|
||||
return True
|
||||
|
||||
# Retry on rate limits.
|
||||
if status_code == 429:
|
||||
return True
|
||||
|
||||
# Retry internal errors.
|
||||
if status_code >= 500:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _calculate_retry_after(remaining_retries: int, max_retries: int, response_headers: Optional[httpx.Headers]=None):
|
||||
"""
|
||||
Reimplementation of openai's calculate retry after, since that one can't be imported.
|
||||
https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L631
|
||||
"""
|
||||
try:
|
||||
import email # openai import
|
||||
# About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
|
||||
#
|
||||
# <http-date>". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for
|
||||
# details.
|
||||
if response_headers is not None:
|
||||
retry_header = response_headers.get("retry-after")
|
||||
try:
|
||||
retry_after = int(retry_header)
|
||||
except Exception:
|
||||
retry_date_tuple = email.utils.parsedate_tz(retry_header)
|
||||
if retry_date_tuple is None:
|
||||
retry_after = -1
|
||||
else:
|
||||
retry_date = email.utils.mktime_tz(retry_date_tuple)
|
||||
retry_after = int(retry_date - time.time())
|
||||
else:
|
||||
retry_after = -1
|
||||
|
||||
except Exception:
|
||||
retry_after = -1
|
||||
|
||||
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
|
||||
if 0 < retry_after <= 60:
|
||||
return retry_after
|
||||
|
||||
initial_retry_delay = 0.5
|
||||
max_retry_delay = 8.0
|
||||
nb_retries = max_retries - remaining_retries
|
||||
|
||||
# Apply exponential backoff, but not more than the max.
|
||||
sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
|
||||
|
||||
# Apply some jitter, plus-or-minus half a second.
|
||||
jitter = 1 - 0.25 * random.random()
|
||||
timeout = sleep_seconds * jitter
|
||||
return timeout if timeout >= 0 else 0
|
||||
|
||||
# integration helper function
|
||||
def modify_integration(integration_name, integration_params):
|
||||
global supabaseClient
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue