fix(router.py): don't retry malformed / content policy violating errors (400 status code)

https://github.com/BerriAI/litellm/issues/1317, https://github.com/BerriAI/litellm/issues/1316
Author: Krrish Dholakia, 2024-01-04 22:23:51 +05:30
parent aa72d65c90
commit 25241de69e
2 changed files with 147 additions and 6 deletions
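
Why fail fast on 400: the status code marks a deterministic client error. Whether the request is malformed or tripped the provider's content policy, resending the identical payload, to the same deployment or to a fallback, yields the same 400, so retries only add latency and burn rate-limit budget. A minimal sketch of the rule (simplified names, not the actual Router internals; it assumes provider exceptions expose a status_code attribute, as litellm's mapped exceptions do):

    def is_retryable(exc: Exception) -> bool:
        # Sketch only. 400 = malformed request or content-policy violation:
        # deterministic, so retrying or falling back cannot change the outcome.
        if getattr(exc, "status_code", None) == 400:
            return False
        return True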

litellm/router.py

@@ -773,6 +773,10 @@ class Router:
             )
             original_exception = e
             try:
+                if (
+                    hasattr(e, "status_code") and e.status_code == 400
+                ):  # don't retry a malformed request
+                    raise e
                 self.print_verbose(f"Trying to fallback b/w models")
                 if (
                     isinstance(e, litellm.ContextWindowExceededError)
@@ -846,7 +850,7 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
             if (
                 isinstance(original_exception, litellm.ContextWindowExceededError)
                 and context_window_fallbacks is None
@@ -864,12 +868,12 @@ class Router:
                     min_timeout=self.retry_after,
                 )
                 await asyncio.sleep(timeout)
-            elif (
-                hasattr(original_exception, "status_code")
-                and hasattr(original_exception, "response")
-                and litellm._should_retry(status_code=original_exception.status_code)
+            elif hasattr(original_exception, "status_code") and litellm._should_retry(
+                status_code=original_exception.status_code
             ):
-                if hasattr(original_exception.response, "headers"):
+                if hasattr(original_exception, "response") and hasattr(
+                    original_exception.response, "headers"
+                ):
                     timeout = litellm._calculate_retry_after(
                         remaining_retries=num_retries,
                         max_retries=num_retries,
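
Errors other than 400 still go through litellm._should_retry, whose call site appears above; the diff does not include its body. A rough sketch of the assumed classification (litellm follows the OpenAI SDK convention here, but check litellm's source for the authoritative rules):

    def should_retry_sketch(status_code: int) -> bool:
        # Assumed semantics, not litellm's actual code.
        if status_code in (408, 409, 429):  # timeout, conflict, rate limit
            return True
        if status_code >= 500:  # transient server-side failure
            return True
        return False  # remaining 4xx are caller errors; retrying cannot fix them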

(new test file; path not shown in this view)

@@ -0,0 +1,137 @@
+#### What this tests ####
+#    This tests if the router sends back a policy violation, without retries
+
+import sys, os, time
+import traceback, asyncio
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import litellm
+from litellm import Router
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class MyCustomHandler(CustomLogger):
+    success: bool = False
+    failure: bool = False
+    previous_models: int = 0
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        print(f"Pre-API Call")
+        print(
+            f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
+        )
+        self.previous_models += len(
+            kwargs["litellm_params"]["metadata"]["previous_models"]
+        )  # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
+        print(f"self.previous_models: {self.previous_models}")
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        print(
+            f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
+        )
+
+    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Stream")
+
+    def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Stream")
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Failure")
+
+
+kwargs = {
+    "model": "azure/gpt-3.5-turbo",
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {
+            "role": "user",
+            "content": "vorrei vedere la cosa più bella ad Ercolano. Qualè?",
+        },
+    ],
+}
+
+
+@pytest.mark.asyncio
+async def test_async_fallbacks():
+    litellm.set_verbose = False
+    model_list = [
+        {  # list of model deployments
+            "model_name": "azure/gpt-3.5-turbo-context-fallback",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-functioncalling",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+        {
+            "model_name": "gpt-3.5-turbo-16k",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo-16k",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        num_retries=3,
+        fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
+        # context_window_fallbacks=[
+        #     {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
+        #     {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
+        # ],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    try:
+        response = await router.acompletion(**kwargs)
+        pytest.fail(
+            f"Expected an Azure policy violation error, got response: {response}"
+        )  # should've raised azure policy error
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        await asyncio.sleep(
+            0.05
+        )  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 0  # 0 retries, 0 fallback
+        router.reset()
+    finally:
+        router.reset()
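
The test requires live credentials (AZURE_API_KEY, AZURE_API_VERSION, AZURE_API_BASE, plus OPENAI_API_KEY for the fallback deployments). The user prompt, Italian for "I would like to see the most beautiful thing in Ercolano. Which is it?", is one that, per the linked issues, Azure's content filter rejected with a 400 at the time. Expected outcome: the router raises immediately and customHandler.previous_models stays 0, i.e. no retries and no fallbacks. One way to drive the coroutine outside pytest (any file path you save it under is your own assumption; this view does not show the new file's name):

    import asyncio

    # @pytest.mark.asyncio only marks the function; it is still a plain
    # coroutine function, so it can be run directly.
    asyncio.run(test_async_fallbacks())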