From bb1267eb07c2adaeca62eb525a53a886f749ddef Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 28 Nov 2023 17:24:49 -0800
Subject: [PATCH] fix(router.py): fix exponential backoff to use retry-after if present in headers

---
 litellm/__init__.py                    |  4 +-
 litellm/router.py                      | 60 +++++++++++++---------
 litellm/tests/test_async_fn.py         |  4 +-
 litellm/tests/test_custom_logger.py    |  8 ++-
 litellm/tests/test_exceptions.py       | 71 +++++++++++++-------------
 litellm/tests/test_profiling_router.py |  3 +-
 litellm/utils.py                       | 71 +++++++++++++++++++++++++-
 7 files changed, 154 insertions(+), 67 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index aeaa97fc3a..121acda724 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -364,7 +364,9 @@ from .utils import (
     completion_with_config,
     register_model,
     encode,
-    decode
+    decode,
+    _calculate_retry_after,
+    _should_retry
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
diff --git a/litellm/router.py b/litellm/router.py
index c4648ee3ee..c268325ee7 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -395,7 +395,7 @@ class Router:
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
+        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
         self.print_verbose(f"async function w/ retries: original_function - {original_function}")
         num_retries = kwargs.pop("num_retries")
         try:
@@ -404,11 +404,21 @@
             return response
         except Exception as e:
             original_exception = e
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
             if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
                 or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
                 raise original_exception
             ### RETRY
+            #### check if it should retry + back-off if required
+            if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
+                if hasattr(original_exception.response, "headers"):
+                    timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
+                else:
+                    timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries)
+                await asyncio.sleep(timeout)
+            else:
+                raise original_exception
+
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
@@ -417,21 +427,16 @@
                     if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
                         response = await response
                     return response
-
-                except openai.RateLimitError as e:
-                    if num_retries > 0 and fallbacks is None:
-                        # on RateLimitError we'll wait for an exponential time before trying again
-                        await asyncio.sleep(backoff_factor)
-
-                        # increase backoff factor for next run
-                        backoff_factor *= 2
-                    else:
-                        raise e
-
-                except Exception as e:
-                    # for any other exception types, immediately retry
-                    if num_retries > 0:
-                        pass
+
+                except Exception as e:
+                    if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
+                        remaining_retries = num_retries - current_attempt
+                        if hasattr(e.response, "headers"):
+                            timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, response_headers=e.response.headers)
+                        else:
+                            timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
+                        await asyncio.sleep(timeout)
                     else:
                         raise e
             raise original_exception
@@ -442,8 +447,8 @@
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
-        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
+        fallbacks = kwargs.get("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
         try:
             response = self.function_with_retries(*args, **kwargs)
             return response
@@ -507,6 +512,8 @@
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
         num_retries = kwargs.pop("num_retries")
+        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = original_function(*args, **kwargs)
@@ -514,6 +521,11 @@
         except Exception as e:
             original_exception = e
             self.print_verbose(f"num retries in function with retries: {num_retries}")
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
+            if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
+                or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
+                raise original_exception
+            ### RETRY
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
                 try:
@@ -523,11 +535,10 @@
                 except openai.RateLimitError as e:
                     if num_retries > 0:
+                        remaining_retries = num_retries - current_attempt
+                        timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
                         # on RateLimitError we'll wait for an exponential time before trying again
-                        time.sleep(backoff_factor)
-
-                        # increase backoff factor for next run
-                        backoff_factor *= 2
+                        time.sleep(timeout)
                     else:
                         raise e
 
@@ -633,7 +644,6 @@
         else:
             self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=60)
 
-
     def _get_cooldown_deployments(self):
         """
         Get the list of models being cooled down for this minute
@@ -919,7 +929,7 @@
             return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
 
         raise ValueError("No models available.")
-    
+
     def flush_cache(self):
         self.cache.flush_cache()
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index 69bd97019c..05446f9f46 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -74,6 +74,8 @@
     asyncio.run(test_get_response())
 
+# test_async_response_azure()
+
 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True
@@ -162,4 +164,4 @@
         return response
     asyncio.run(test_async_call())
-test_get_response_non_openai_streaming()
\ No newline at end of file
+# test_get_response_non_openai_streaming()
\ No newline at end of file
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index b6c0894c44..1527423ca8 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -37,7 +37,13 @@
                         }],
                         stream=True,
                         complete_response = True)
-
+        response2 = completion(model="gpt-3.5-turbo",
+                        messages=[{
+                            "role": "user",
+                            "content": "Hi 👋 - i'm not openai"
+                        }],
+                        stream=True,
+                        complete_response = True)
         time.sleep(1)
         assert customHandler.success == True
     except Exception as e:
diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py
index e473af80ad..9cfecedd43 100644
--- a/litellm/tests/test_exceptions.py
+++ b/litellm/tests/test_exceptions.py
@@ -227,10 +227,10 @@ async def asynctest_completion_azure_exception():
         print("exception", e)
         pytest.fail(f"Error occurred: {e}")
 
-import asyncio
-asyncio.run(
-    asynctest_completion_azure_exception()
-)
+# import asyncio
+# asyncio.run(
+#     asynctest_completion_azure_exception()
+# )
 
 def test_completion_openai_exception():
@@ -265,39 +265,40 @@
 # test_invalid_request_error(model="command-nightly")
 
 # Test 3: Rate Limit Errors
-# def test_model_call(model):
-#     try:
-#         sample_text = "how does a court case get to the Supreme Court?"
-#         messages = [{ "content": sample_text,"role": "user"}]
-#         print(f"model: {model}")
-#         response = completion(model=model, messages=messages)
-#     except RateLimitError:
-#         return True
-#     # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
-#     #     return True
-#     except Exception as e:
-#         print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
-#         traceback.print_exc()
-#         pass
-#     return False
-# # Repeat each model 500 times
-# # extended_models = [model for model in models for _ in range(250)]
-# extended_models = ["gpt-3.5-turbo-instruct" for _ in range(250)]
+def test_model_call(model):
+    try:
+        sample_text = "how does a court case get to the Supreme Court?"
+        messages = [{ "content": sample_text,"role": "user"}]
+        print(f"model: {model}")
+        response = completion(model=model, messages=messages)
+    except RateLimitError as e:
+        print(f"headers: {e.response.headers}")
+        return True
+    # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
+    #     return True
+    except Exception as e:
+        print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
+        traceback.print_exc()
+        pass
+    return False
+# Repeat each model 500 times
+# extended_models = [model for model in models for _ in range(250)]
+extended_models = ["azure/chatgpt-v-2" for _ in range(250)]
 
-# def worker(model):
-#     return test_model_call(model)
+def worker(model):
+    return test_model_call(model)
 
-# # Create a dictionary to store the results
-# counts = {True: 0, False: 0}
+# Create a dictionary to store the results
+counts = {True: 0, False: 0}
 
-# # Use Thread Pool Executor
-# with ThreadPoolExecutor(max_workers=500) as executor:
-#     # Use map to start the operation in thread pool
-#     results = executor.map(worker, extended_models)
+# Use Thread Pool Executor
+with ThreadPoolExecutor(max_workers=500) as executor:
+    # Use map to start the operation in thread pool
+    results = executor.map(worker, extended_models)
 
-#     # Iterate over results and count True/False
-#     for result in results:
-#         counts[result] += 1
+    # Iterate over results and count True/False
+    for result in results:
+        counts[result] += 1
 
-# accuracy_score = counts[True]/(counts[True] + counts[False])
-# print(f"accuracy_score: {accuracy_score}")
+accuracy_score = counts[True]/(counts[True] + counts[False])
+print(f"accuracy_score: {accuracy_score}")
diff --git a/litellm/tests/test_profiling_router.py b/litellm/tests/test_profiling_router.py
index 383d9dd692..a27a3eafaf 100644
--- a/litellm/tests/test_profiling_router.py
+++ b/litellm/tests/test_profiling_router.py
@@ -55,7 +55,6 @@
 #         try:
 #             messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
 #             response = await router.acompletion(model="azure-model", messages=messages)
-#             # response = await litellm.acompletion(model="azure/gpt-35-turbo", messages=messages, api_key="6a0f46e99d554e8caad9c2b7c0ba7319", api_base="https://my-endpoint-canada-berri992.openai.azure.com")
 #             return response
 #         except Exception as e:
 #             print(e, file=sys.stderr)
@@ -64,7 +63,7 @@
 
 # async def loadtest_fn():
 #     start = time.time()
-#     n = 100
+#     n = 1000
 #     tasks = [router_completion() for _ in range(n)]
 #     chat_completions = await asyncio.gather(*tasks)
 #     successful_completions = [c for c in chat_completions if c is not None]
diff --git a/litellm/utils.py b/litellm/utils.py
index 4b092fd636..b56890cd2c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2735,7 +2735,6 @@
     return python_to_json_schema_types.get(python_type_name, "string")
 
-
 def function_to_dict(input_function): # noqa: C901
     """Using type hints and numpy-styled docstring, produce a dictionnary usable for OpenAI function calling
@@ -3136,7 +3135,7 @@ def set_callbacks(callback_list, function_id=None):
     except Exception as e:
         raise e
 
-
+# NOTE: DEPRECATING this in favor of using failure_handler() in Logging:
 def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
     try:
@@ -3473,6 +3472,74 @@ def check_valid_key(model: str, api_key: str):
     except Exception as e:
         return False
 
+def _should_retry(status_code: int):
+    """
+    Reimplementation of openai's should retry logic, since that one can't be imported.
+    https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L639
+    """
+    # If the server explicitly says whether or not to retry, obey.
+    # Retry on request timeouts.
+    if status_code == 408:
+        return True
+
+    # Retry on lock timeouts.
+    if status_code == 409:
+        return True
+
+    # Retry on rate limits.
+    if status_code == 429:
+        return True
+
+    # Retry internal errors.
+    if status_code >= 500:
+        return True
+
+    return False
+
+def _calculate_retry_after(remaining_retries: int, max_retries: int, response_headers: Optional[httpx.Headers]=None):
+    """
+    Reimplementation of openai's calculate retry after, since that one can't be imported.
+    https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L631
+    """
+    try:
+        import email # openai import
+        # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
+        #
+        # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for
+        # details.
+        if response_headers is not None:
+            retry_header = response_headers.get("retry-after")
+            try:
+                retry_after = int(retry_header)
+            except Exception:
+                retry_date_tuple = email.utils.parsedate_tz(retry_header)
+                if retry_date_tuple is None:
+                    retry_after = -1
+                else:
+                    retry_date = email.utils.mktime_tz(retry_date_tuple)
+                    retry_after = int(retry_date - time.time())
+        else:
+            retry_after = -1
+
+    except Exception:
+        retry_after = -1
+
+    # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
+    if 0 < retry_after <= 60:
+        return retry_after
+
+    initial_retry_delay = 0.5
+    max_retry_delay = 8.0
+    nb_retries = max_retries - remaining_retries
+
+    # Apply exponential backoff, but not more than the max.
+    sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
+
+    # Apply some jitter, plus-or-minus half a second.
+    jitter = 1 - 0.25 * random.random()
+    timeout = sleep_seconds * jitter
+    return timeout if timeout >= 0 else 0
+
 # integration helper function
 def modify_integration(integration_name, integration_params):
     global supabaseClient
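
A minimal usage sketch of the two helpers this patch adds, mirroring how the router hunks above call them. It is illustrative only and assumes the patch is applied; the `sleep_before_retry` helper name is mine, not something introduced by litellm.

# Illustrative sketch: compose litellm._should_retry / litellm._calculate_retry_after
import asyncio
import httpx
import litellm

async def sleep_before_retry(exc, remaining_retries: int, max_retries: int):
    # Same shape as the router's retry loop: only back off for retryable status codes,
    # and pass the response headers through so a Retry-After header is honored.
    if hasattr(exc, "status_code") and hasattr(exc, "response") and litellm._should_retry(status_code=exc.status_code):
        if hasattr(exc.response, "headers"):
            timeout = litellm._calculate_retry_after(
                remaining_retries=remaining_retries,
                max_retries=max_retries,
                response_headers=exc.response.headers,
            )
        else:
            timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=max_retries)
        await asyncio.sleep(timeout)
    else:
        raise exc

# Standalone check of the backoff math: a Retry-After header wins over exponential backoff.
headers = httpx.Headers({"retry-after": "7"})
print(litellm._calculate_retry_after(remaining_retries=2, max_retries=3, response_headers=headers))  # -> 7
print(litellm._calculate_retry_after(remaining_retries=2, max_retries=3))  # -> roughly 0.75-1.0 (0.5 * 2**1, jittered)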