fix(router.py): don't retry malformed / content policy violating errors (400 status code)
Refs: https://github.com/BerriAI/litellm/issues/1317, https://github.com/BerriAI/litellm/issues/1316
parent aa72d65c90
commit 25241de69e
2 changed files with 147 additions and 6 deletions
litellm/router.py

@@ -773,6 +773,10 @@ class Router:
             )
             original_exception = e
             try:
+                if (
+                    hasattr(e, "status_code") and e.status_code == 400
+                ):  # don't retry a malformed request
+                    raise e
                 self.print_verbose(f"Trying to fallback b/w models")
                 if (
                     isinstance(e, litellm.ContextWindowExceededError)
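A minimal standalone sketch of the pattern this hunk introduces (the function and argument names below are illustrative stand-ins, not litellm internals): a 400-class error is re-raised immediately, since retrying the same malformed or policy-violating payload can never succeed, while other failures still proceed to the fallback deployments.

async def completion_with_fallbacks(call_model, fallbacks):
    # Illustrative only: call the primary deployment, then fall back on errors,
    # unless the error is a client-side 400, which is surfaced immediately.
    try:
        return await call_model()
    except Exception as e:
        if getattr(e, "status_code", None) == 400:
            # Malformed / content-policy-violating request: retrying the same
            # payload can never succeed, so fail fast.
            raise
        for fallback in fallbacks:
            try:
                return await fallback()
            except Exception:
                continue  # try the next deployment
        raise e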
@@ -846,7 +850,7 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
             if (
                 isinstance(original_exception, litellm.ContextWindowExceededError)
                 and context_window_fallbacks is None
@@ -864,12 +868,12 @@ class Router:
                     min_timeout=self.retry_after,
                 )
                 await asyncio.sleep(timeout)
-            elif (
-                hasattr(original_exception, "status_code")
-                and hasattr(original_exception, "response")
-                and litellm._should_retry(status_code=original_exception.status_code)
-            ):
-                if hasattr(original_exception.response, "headers"):
+            elif hasattr(original_exception, "status_code") and litellm._should_retry(
+                status_code=original_exception.status_code
+            ):
+                if hasattr(original_exception, "response") and hasattr(
+                    original_exception.response, "headers"
+                ):
                     timeout = litellm._calculate_retry_after(
                         remaining_retries=num_retries,
                         max_retries=num_retries,
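The reworked condition separates "is this error retryable?" (status code alone) from "can we honor the server's Retry-After?" (requires a response object with headers). A rough sketch of that backoff pattern, using hypothetical helper names rather than litellm's internal _should_retry / _calculate_retry_after:

import asyncio

RETRYABLE_STATUS_CODES = {408, 429, 500, 502, 503, 504}  # illustrative set

def should_retry(status_code: int) -> bool:
    # Hypothetical stand-in for litellm._should_retry.
    return status_code in RETRYABLE_STATUS_CODES

async def backoff_before_retry(exc: Exception, attempt: int, min_timeout: float = 1.0) -> None:
    # Sleep before the next retry, honoring a server-sent Retry-After header
    # when the exception carries a response object with headers.
    timeout = max(min_timeout, float(2**attempt))  # simple exponential default
    response = getattr(exc, "response", None)
    headers = getattr(response, "headers", None)
    if headers is not None and "retry-after" in headers:
        try:
            timeout = float(headers["retry-after"])
        except ValueError:
            pass  # Retry-After in HTTP-date form; keep the default backoff
    await asyncio.sleep(timeout)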
litellm/tests/test_router_policy_violation.py (new file, 137 lines)

@@ -0,0 +1,137 @@
#### What this tests ####
# Checks that the router surfaces a content policy violation (400) without retrying

import sys, os, time
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger

class MyCustomHandler(CustomLogger):
    success: bool = False
    failure: bool = False
    previous_models: int = 0

    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")
        print(
            f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
        )
        self.previous_models += len(
            kwargs["litellm_params"]["metadata"]["previous_models"]
        )  # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
        print(f"self.previous_models: {self.previous_models}")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print(
            f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
        )

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        # async in the CustomLogger base class, so the override must be async too
        print(f"On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")

kwargs = {
    "model": "azure/gpt-3.5-turbo",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            # Italian: "I would like to see the most beautiful thing in Ercolano. What is it?"
            "content": "vorrei vedere la cosa più bella ad Ercolano. Qual’è?",
        },
    ],
}

@pytest.mark.asyncio
async def test_async_fallbacks():
    litellm.set_verbose = False
    model_list = [
        {  # list of model deployments
            "model_name": "azure/gpt-3.5-turbo-context-fallback",  # openai model name
            "litellm_params": {  # params for litellm completion/embedding call
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            "tpm": 240000,
            "rpm": 1800,
        },
        {
            "model_name": "azure/gpt-3.5-turbo",  # openai model name
            "litellm_params": {  # params for litellm completion/embedding call
                "model": "azure/chatgpt-functioncalling",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            "tpm": 240000,
            "rpm": 1800,
        },
        {
            "model_name": "gpt-3.5-turbo",  # openai model name
            "litellm_params": {  # params for litellm completion/embedding call
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "tpm": 1000000,
            "rpm": 9000,
        },
        {
            "model_name": "gpt-3.5-turbo-16k",  # openai model name
            "litellm_params": {  # params for litellm completion/embedding call
                "model": "gpt-3.5-turbo-16k",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "tpm": 1000000,
            "rpm": 9000,
        },
    ]

    router = Router(
        model_list=model_list,
        num_retries=3,
        fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
        # context_window_fallbacks=[
        #     {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
        #     {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
        # ],
        set_verbose=False,
    )
    customHandler = MyCustomHandler()
    litellm.callbacks = [customHandler]
    try:
        response = await router.acompletion(**kwargs)
        pytest.fail(
            "Expected the router to raise an Azure content policy violation error"
        )
    except litellm.Timeout as e:
        pass
    except Exception as e:
        await asyncio.sleep(
            0.05
        )  # allow a delay as success_callbacks are on a separate thread
        assert customHandler.previous_models == 0  # 0 retries, 0 fallback
        router.reset()
    finally:
        router.reset()
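To run the new test locally, something like the following should work (an assumption, not part of the commit): it is equivalent to invoking pytest on the file from the repo root, and it requires the AZURE_API_KEY / AZURE_API_VERSION / AZURE_API_BASE and OPENAI_API_KEY environment variables the test reads via os.getenv.

import pytest

# Run just this file, stopping at the first failure.
pytest.main(["-x", "litellm/tests/test_router_policy_violation.py"])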