Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

fix(router.py): fix exponential backoff to use retry-after if present in headers

Parent: 7f34298ef8
Commit: 60d6b6bc37

7 changed files with 154 additions and 67 deletions

@@ -364,7 +364,9 @@ from .utils import (
     completion_with_config,
     register_model,
     encode,
-    decode
+    decode,
+    _calculate_retry_after,
+    _should_retry
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig

@@ -395,7 +395,7 @@ class Router:
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
+        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
         self.print_verbose(f"async function w/ retries: original_function - {original_function}")
         num_retries = kwargs.pop("num_retries")
         try:

@@ -404,11 +404,21 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
             if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
                 or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
                 raise original_exception
             ### RETRY
+            #### check if it should retry + back-off if required
+            if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
+                if hasattr(original_exception.response, "headers"):
+                    timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
+                else:
+                    timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
+                await asyncio.sleep(timeout)
+            else:
+                raise original_exception
+
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:

@@ -417,21 +427,16 @@ class Router:
                     if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
                         response = await response
                     return response

-                except openai.RateLimitError as e:
-                    if num_retries > 0 and fallbacks is None:
-                        # on RateLimitError we'll wait for an exponential time before trying again
-                        await asyncio.sleep(backoff_factor)
-
-                        # increase backoff factor for next run
-                        backoff_factor *= 2
-                    else:
-                        raise e
-
-                except Exception as e:
-                    # for any other exception types, immediately retry
-                    if num_retries > 0:
-                        pass
+                except Exception as e:
+                    if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
+                        remaining_retries = num_retries - current_attempt
+                        if hasattr(e.response, "headers"):
+                            timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, response_headers=e.response.headers)
+                        else:
+                            timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
+                        await asyncio.sleep(timeout)
                     else:
                         raise e
             raise original_exception

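Taken together, the router change above stops doubling a local backoff_factor and instead asks litellm._should_retry() whether the status code is worth retrying, then asks litellm._calculate_retry_after() how long to sleep, preferring the provider's Retry-After header and only then falling back to capped exponential backoff with jitter. A minimal standalone sketch of that pattern, assuming only the two helpers added by this commit; the call_with_retries wrapper itself is illustrative and not part of the diff:

import asyncio
import litellm

async def call_with_retries(fn, num_retries=3):
    # Illustrative wrapper mirroring the router's async retry logic above.
    for current_attempt in range(num_retries + 1):
        try:
            return await fn()
        except Exception as e:
            status_code = getattr(e, "status_code", None)
            response = getattr(e, "response", None)
            out_of_retries = current_attempt == num_retries
            if out_of_retries or status_code is None or not litellm._should_retry(status_code=status_code):
                raise
            # Prefer the server's Retry-After header; otherwise exponential backoff + jitter.
            timeout = litellm._calculate_retry_after(
                remaining_retries=num_retries - current_attempt,
                max_retries=num_retries,
                response_headers=getattr(response, "headers", None),
            )
            await asyncio.sleep(timeout)
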
@@ -442,8 +447,8 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
-        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
+        fallbacks = kwargs.get("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
         try:
             response = self.function_with_retries(*args, **kwargs)
             return response

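The pop-to-get switch here pairs with the get-to-pop changes in the retry helpers above and below: presumably the fallback wrapper only needs to read the fallback lists, while the retry wrapper it delegates to must remove them from kwargs before the underlying completion call sees them. A toy sketch of the difference; the function names are made up for illustration:

def with_fallbacks(**kwargs):
    # get() reads the value but leaves it in kwargs for the nested call
    fallbacks = kwargs.get("fallbacks", None)
    return with_retries(**kwargs)

def with_retries(**kwargs):
    # pop() reads the value and strips it, so it is not forwarded downstream
    fallbacks = kwargs.pop("fallbacks", None)
    return fallbacks

print(with_fallbacks(fallbacks=["gpt-3.5-turbo"]))  # ['gpt-3.5-turbo']
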
@@ -507,6 +512,8 @@ class Router:
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
         num_retries = kwargs.pop("num_retries")
+        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = original_function(*args, **kwargs)

@@ -514,6 +521,11 @@ class Router:
         except Exception as e:
             original_exception = e
             self.print_verbose(f"num retries in function with retries: {num_retries}")
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
+            if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
+                or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
+                raise original_exception
+            ### RETRY
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
                 try:

@@ -523,11 +535,10 @@ class Router:

                 except openai.RateLimitError as e:
                     if num_retries > 0:
+                        remaining_retries = num_retries - current_attempt
+                        timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
                         # on RateLimitError we'll wait for an exponential time before trying again
-                        time.sleep(backoff_factor)
-
-                        # increase backoff factor for next run
-                        backoff_factor *= 2
+                        time.sleep(timeout)
                     else:
                         raise e

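For context, these retry paths run whenever a Router is created with num_retries set. A minimal usage sketch; the deployment entry and key below are placeholders, not taken from this commit:

from litellm import Router

model_list = [{
    "model_name": "gpt-3.5-turbo",             # model group name callers use
    "litellm_params": {
        "model": "gpt-3.5-turbo",
        "api_key": "sk-...",                   # placeholder
    },
}]

router = Router(model_list=model_list, num_retries=3)

# On a retryable error (e.g. a 429), the router now sleeps for the provider's
# Retry-After value when one is sent, instead of a locally doubled backoff_factor.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)
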
@@ -633,7 +644,6 @@ class Router:
             else:
                 self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)

-
     def _get_cooldown_deployments(self):
         """
         Get the list of models being cooled down for this minute

@@ -919,7 +929,7 @@ class Router:
             return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)

         raise ValueError("No models available.")

     def flush_cache(self):
         self.cache.flush_cache()

@@ -74,6 +74,8 @@ def test_async_response_azure():

     asyncio.run(test_get_response())

+# test_async_response_azure()
+
 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True

@@ -162,4 +164,4 @@ def test_get_response_non_openai_streaming():
         return response
     asyncio.run(test_async_call())

-test_get_response_non_openai_streaming()
+# test_get_response_non_openai_streaming()

@@ -37,7 +37,13 @@ def test_chat_openai():
                               }],
                               stream=True,
                               complete_response = True)
-
+        response2 = completion(model="gpt-3.5-turbo",
+                               messages=[{
+                                   "role": "user",
+                                   "content": "Hi 👋 - i'm not openai"
+                               }],
+                               stream=True,
+                               complete_response = True)
         time.sleep(1)
         assert customHandler.success == True
     except Exception as e:

@@ -227,10 +227,10 @@ async def asynctest_completion_azure_exception():
         print("exception", e)
         pytest.fail(f"Error occurred: {e}")

-import asyncio
-asyncio.run(
-    asynctest_completion_azure_exception()
-)
+# import asyncio
+# asyncio.run(
+#     asynctest_completion_azure_exception()
+# )


 def test_completion_openai_exception():

@@ -265,39 +265,40 @@ def test_completion_openai_exception():

 # test_invalid_request_error(model="command-nightly")
 # Test 3: Rate Limit Errors
-# def test_model_call(model):
-#     try:
-#         sample_text = "how does a court case get to the Supreme Court?"
-#         messages = [{ "content": sample_text,"role": "user"}]
-#         print(f"model: {model}")
-#         response = completion(model=model, messages=messages)
-#     except RateLimitError:
-#         return True
-#     # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
-#     #     return True
-#     except Exception as e:
-#         print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
-#         traceback.print_exc()
-#         pass
-#         return False
-# # Repeat each model 500 times
-# # extended_models = [model for model in models for _ in range(250)]
-# extended_models = ["gpt-3.5-turbo-instruct" for _ in range(250)]
+def test_model_call(model):
+    try:
+        sample_text = "how does a court case get to the Supreme Court?"
+        messages = [{ "content": sample_text,"role": "user"}]
+        print(f"model: {model}")
+        response = completion(model=model, messages=messages)
+    except RateLimitError as e:
+        print(f"headers: {e.response.headers}")
+        return True
+    # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
+    #     return True
+    except Exception as e:
+        print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
+        traceback.print_exc()
+        pass
+        return False
+# Repeat each model 500 times
+# extended_models = [model for model in models for _ in range(250)]
+extended_models = ["azure/chatgpt-v-2" for _ in range(250)]

-# def worker(model):
-#     return test_model_call(model)
+def worker(model):
+    return test_model_call(model)

-# # Create a dictionary to store the results
-# counts = {True: 0, False: 0}
+# Create a dictionary to store the results
+counts = {True: 0, False: 0}

-# # Use Thread Pool Executor
-# with ThreadPoolExecutor(max_workers=500) as executor:
-#     # Use map to start the operation in thread pool
-#     results = executor.map(worker, extended_models)
+# Use Thread Pool Executor
+with ThreadPoolExecutor(max_workers=500) as executor:
+    # Use map to start the operation in thread pool
+    results = executor.map(worker, extended_models)

-# # Iterate over results and count True/False
-# for result in results:
-#     counts[result] += 1
+# Iterate over results and count True/False
+for result in results:
+    counts[result] += 1

-# accuracy_score = counts[True]/(counts[True] + counts[False])
-# print(f"accuracy_score: {accuracy_score}")
+accuracy_score = counts[True]/(counts[True] + counts[False])
+print(f"accuracy_score: {accuracy_score}")

@@ -55,7 +55,6 @@
#     try:
#         messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
#         response = await router.acompletion(model="azure-model", messages=messages)
-#         # response = await litellm.acompletion(model="azure/gpt-35-turbo", messages=messages, api_key="6a0f46e99d554e8caad9c2b7c0ba7319", api_base="https://my-endpoint-canada-berri992.openai.azure.com")
#         return response
#     except Exception as e:
#         print(e, file=sys.stderr)

@@ -64,7 +63,7 @@

# async def loadtest_fn():
#     start = time.time()
-#     n = 100
+#     n = 1000
#     tasks = [router_completion() for _ in range(n)]
#     chat_completions = await asyncio.gather(*tasks)
#     successful_completions = [c for c in chat_completions if c is not None]

@@ -2735,7 +2735,6 @@ def json_schema_type(python_type_name: str):

     return python_to_json_schema_types.get(python_type_name, "string")

-
 def function_to_dict(input_function):  # noqa: C901
     """Using type hints and numpy-styled docstring,
     produce a dictionnary usable for OpenAI function calling

@@ -3136,7 +3135,7 @@ def set_callbacks(callback_list, function_id=None):
     except Exception as e:
         raise e

-
+# NOTE: DEPRECATING this in favor of using failure_handler() in Logging:
 def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
     try:

@@ -3473,6 +3472,74 @@ def check_valid_key(model: str, api_key: str):
     except Exception as e:
         return False

+def _should_retry(status_code: int):
+    """
+    Reimplementation of openai's should retry logic, since that one can't be imported.
+    https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L639
+    """
+    # If the server explicitly says whether or not to retry, obey.
+    # Retry on request timeouts.
+    if status_code == 408:
+        return True
+
+    # Retry on lock timeouts.
+    if status_code == 409:
+        return True
+
+    # Retry on rate limits.
+    if status_code == 429:
+        return True
+
+    # Retry internal errors.
+    if status_code >= 500:
+        return True
+
+    return False
+
+def _calculate_retry_after(remaining_retries: int, max_retries: int, response_headers: Optional[httpx.Headers]=None):
+    """
+    Reimplementation of openai's calculate retry after, since that one can't be imported.
+    https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L631
+    """
+    try:
+        import email # openai import
+        # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
+        #
+        # <http-date>". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for
+        # details.
+        if response_headers is not None:
+            retry_header = response_headers.get("retry-after")
+            try:
+                retry_after = int(retry_header)
+            except Exception:
+                retry_date_tuple = email.utils.parsedate_tz(retry_header)
+                if retry_date_tuple is None:
+                    retry_after = -1
+                else:
+                    retry_date = email.utils.mktime_tz(retry_date_tuple)
+                    retry_after = int(retry_date - time.time())
+        else:
+            retry_after = -1
+
+    except Exception:
+        retry_after = -1
+
+    # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
+    if 0 < retry_after <= 60:
+        return retry_after
+
+    initial_retry_delay = 0.5
+    max_retry_delay = 8.0
+    nb_retries = max_retries - remaining_retries
+
+    # Apply exponential backoff, but not more than the max.
+    sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
+
+    # Apply some jitter, plus-or-minus half a second.
+    jitter = 1 - 0.25 * random.random()
+    timeout = sleep_seconds * jitter
+    return timeout if timeout >= 0 else 0
+
 # integration helper function
 def modify_integration(integration_name, integration_params):
     global supabaseClient
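
To make the behavior of the two new helpers concrete, here is a short illustrative snippet based directly on the code above. It assumes litellm re-exports the helpers (as the first hunk does) and that httpx is available (it is already a litellm dependency); the value printed for the no-header case is approximate because of the jitter:

import httpx
import litellm

# _should_retry: retry on 408/409/429 and on 5xx, but not on ordinary 4xx client errors.
assert litellm._should_retry(status_code=429) is True
assert litellm._should_retry(status_code=500) is True
assert litellm._should_retry(status_code=400) is False

# If the provider sent a Retry-After header with a reasonable value (<= 60s), that wins.
headers = httpx.Headers({"retry-after": "7"})
print(litellm._calculate_retry_after(remaining_retries=2, max_retries=3, response_headers=headers))  # 7

# Without headers: capped exponential backoff with jitter,
# min(0.5 * 2**(max_retries - remaining_retries), 8.0) scaled by a factor in (0.75, 1.0].
print(litellm._calculate_retry_after(remaining_retries=2, max_retries=3))  # roughly 0.75-1.0 seconds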