From 25241de69e571bf9159cdf67a02b10a0dc7306c6 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 4 Jan 2024 22:23:51 +0530
Subject: [PATCH] fix(router.py): don't retry malformed / content policy
 violating errors (400 status code)

https://github.com/BerriAI/litellm/issues/1317,
https://github.com/BerriAI/litellm/issues/1316
---
 litellm/router.py                             |  16 +-
 litellm/tests/test_router_policy_violation.py | 137 ++++++++++++++++++
 2 files changed, 147 insertions(+), 6 deletions(-)
 create mode 100644 litellm/tests/test_router_policy_violation.py

diff --git a/litellm/router.py b/litellm/router.py
index e222a9336..770098df0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -773,6 +773,10 @@ class Router:
             )
             original_exception = e
             try:
+                if (
+                    hasattr(e, "status_code") and e.status_code == 400
+                ):  # don't retry a malformed request
+                    raise e
                 self.print_verbose(f"Trying to fallback b/w models")
                 if (
                     isinstance(e, litellm.ContextWindowExceededError)
@@ -846,7 +850,7 @@
                 return response
             except Exception as e:
                 original_exception = e
-                ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
+                ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
                 if (
                     isinstance(original_exception, litellm.ContextWindowExceededError)
                     and context_window_fallbacks is None
@@ -864,12 +868,12 @@
                     min_timeout=self.retry_after,
                 )
                 await asyncio.sleep(timeout)
-            elif (
-                hasattr(original_exception, "status_code")
-                and hasattr(original_exception, "response")
-                and litellm._should_retry(status_code=original_exception.status_code)
+            elif hasattr(original_exception, "status_code") and litellm._should_retry(
+                status_code=original_exception.status_code
             ):
-                if hasattr(original_exception.response, "headers"):
+                if hasattr(original_exception, "response") and hasattr(
+                    original_exception.response, "headers"
+                ):
                     timeout = litellm._calculate_retry_after(
                         remaining_retries=num_retries,
                         max_retries=num_retries,
diff --git a/litellm/tests/test_router_policy_violation.py b/litellm/tests/test_router_policy_violation.py
new file mode 100644
index 000000000..52f50eb59
--- /dev/null
+++ b/litellm/tests/test_router_policy_violation.py
@@ -0,0 +1,137 @@
+#### What this tests ####
+# This tests that the router surfaces a content policy violation (400 error) without retries
+
+import sys, os, time
+import traceback, asyncio
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import litellm
+from litellm import Router
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class MyCustomHandler(CustomLogger):
+    success: bool = False
+    failure: bool = False
+    previous_models: int = 0
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        print("Pre-API Call")
+        print(
+            f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
+        )
+        self.previous_models += len(
+            kwargs["litellm_params"]["metadata"]["previous_models"]
+        )  # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": ...}]}
+        print(f"self.previous_models: {self.previous_models}")
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        print(
+            f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
+        )
+
+    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print("On Stream")
+
+    def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print("On Stream")
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print("On Success")
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print("On Success")
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print("On Failure")
+
+
+kwargs = {
+    "model": "azure/gpt-3.5-turbo",
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {
+            "role": "user",
+            # Italian: "I'd like to see the most beautiful thing in Ercolano. What is it?"
+            "content": "vorrei vedere la cosa più bella ad Ercolano. Qual’è?",
+        },
+    ],
+}
+
+
+@pytest.mark.asyncio
+async def test_async_fallbacks():
+    litellm.set_verbose = False
+    model_list = [
+        {  # list of model deployments
+            "model_name": "azure/gpt-3.5-turbo-context-fallback",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-functioncalling",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+        {
+            "model_name": "gpt-3.5-turbo-16k",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo-16k",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        num_retries=3,
+        fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
+        # context_window_fallbacks=[
+        #     {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
+        #     {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
+        # ],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    try:
+        response = await router.acompletion(**kwargs)
+        pytest.fail(
+            f"Expected a content policy violation error, got response: {response}"
+        )  # should've raised azure policy error
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        await asyncio.sleep(
+            0.05
+        )  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 0  # 0 retries, 0 fallbacks
+    finally:
+        router.reset()
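
In plain terms, the router.py change encodes one rule: a 400-class error (malformed request, content policy violation) is deterministic, so retrying or falling back with the identical payload cannot succeed, while transient failures such as 429 rate limits and 5xx server errors are worth retrying. Below is a minimal, self-contained sketch of that gating rule; `APIError`, `_should_retry`, and `call_with_retries` here are hypothetical stand-ins for illustration, not litellm's actual classes or signatures.

    import asyncio


    class APIError(Exception):
        """Hypothetical error type carrying the provider's HTTP status code."""

        def __init__(self, status_code: int, message: str = ""):
            super().__init__(message or f"HTTP {status_code}")
            self.status_code = status_code


    def _should_retry(status_code: int) -> bool:
        # Retry only transient failures: rate limits and server-side errors.
        return status_code == 429 or status_code >= 500


    async def call_with_retries(make_request, num_retries: int = 3):
        for attempt in range(num_retries + 1):
            try:
                return await make_request()
            except APIError as e:
                # A 400 (malformed request / policy violation) fails the same
                # way on every resend, so surface it to the caller immediately.
                if not _should_retry(e.status_code) or attempt == num_retries:
                    raise
                await asyncio.sleep(2**attempt)  # simple exponential backoff

The test above relies on exactly this short-circuit: the router is configured with num_retries=3 and a fallback deployment, yet customHandler.previous_models must stay 0 because the 400 bypasses both.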
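
The last router.py hunk also hardens the retry-delay lookup: an exception can carry a status_code without any response object (for example, one raised before a response arrived), so both `response` and `response.headers` are checked before a Retry-After value is trusted. Below is a sketch of that pattern; the patch does not show the internals of litellm's `_calculate_retry_after`, so the hypothetical `calculate_retry_after` helper, its header parsing, and the 60-second cap are all assumptions.

    def calculate_retry_after(
        exception, remaining_retries: int, max_retries: int, min_timeout: float = 0.0
    ) -> float:
        # Guard both attribute lookups: exceptions raised before any response
        # arrived (e.g. connection errors) have no .response at all.
        retry_after = None
        if hasattr(exception, "response") and hasattr(exception.response, "headers"):
            raw = exception.response.headers.get("retry-after")
            try:
                retry_after = float(raw) if raw is not None else None
            except (TypeError, ValueError):
                retry_after = None
        if retry_after is not None and 0 < retry_after <= 60:
            return retry_after  # honor a sane server-provided delay
        # Otherwise fall back to capped exponential backoff per attempt.
        attempt = max_retries - remaining_retries
        return max(min_timeout, min(2.0**attempt, 60.0))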