forked from phoenix/litellm-mirror

Merge pull request #3376 from BerriAI/litellm_routing_logic

fix(router.py): unify retry timeout logic across sync + async function_with_retries

commit 9f55a99e98

7 changed files with 265 additions and 82 deletions
@@ -360,7 +360,7 @@ def mock_completion(
     model: str,
     messages: List,
     stream: Optional[bool] = False,
-    mock_response: str = "This is a mock request",
+    mock_response: Union[str, Exception] = "This is a mock request",
     logging=None,
     **kwargs,
 ):
@@ -387,6 +387,20 @@ def mock_completion(
     - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
     """
     try:
+        ## LOGGING
+        if logging is not None:
+            logging.pre_call(
+                input=messages,
+                api_key="mock-key",
+            )
+        if isinstance(mock_response, Exception):
+            raise litellm.APIError(
+                status_code=500,  # type: ignore
+                message=str(mock_response),
+                llm_provider="openai",  # type: ignore
+                model=model,  # type: ignore
+                request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
+            )
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
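The widened mock_response type means a caller can now hand mock_completion an Exception to simulate a provider failure, which is what the new retry tests further down rely on. A minimal sketch of that usage (the model name and message here are placeholders, not taken from the diff):

import litellm

# Passing an Exception as mock_response should make the mocked call raise
# (a litellm.APIError with status 500 per the hunk above) instead of
# returning a canned response.
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response=Exception("simulated provider outage"),
    )
except Exception as err:
    print("mocked failure:", err)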
@@ -1450,40 +1450,47 @@ class Router:
                 raise original_exception
             ### RETRY
             #### check if it should retry + back-off if required
-            if "No models available" in str(
-                e
-            ) or RouterErrors.no_deployments_available.value in str(e):
-                timeout = litellm._calculate_retry_after(
-                    remaining_retries=num_retries,
-                    max_retries=num_retries,
-                    min_timeout=self.retry_after,
-                )
-                await asyncio.sleep(timeout)
-            elif RouterErrors.user_defined_ratelimit_error.value in str(e):
-                raise e  # don't wait to retry if deployment hits user-defined rate-limit
-
-            elif hasattr(original_exception, "status_code") and litellm._should_retry(
-                status_code=original_exception.status_code
-            ):
-                if hasattr(original_exception, "response") and hasattr(
-                    original_exception.response, "headers"
-                ):
-                    timeout = litellm._calculate_retry_after(
-                        remaining_retries=num_retries,
-                        max_retries=num_retries,
-                        response_headers=original_exception.response.headers,
-                        min_timeout=self.retry_after,
-                    )
-                else:
-                    timeout = litellm._calculate_retry_after(
-                        remaining_retries=num_retries,
-                        max_retries=num_retries,
-                        min_timeout=self.retry_after,
-                    )
-                await asyncio.sleep(timeout)
-            else:
-                raise original_exception
+            # if "No models available" in str(
+            #     e
+            # ) or RouterErrors.no_deployments_available.value in str(e):
+            #     timeout = litellm._calculate_retry_after(
+            #         remaining_retries=num_retries,
+            #         max_retries=num_retries,
+            #         min_timeout=self.retry_after,
+            #     )
+            #     await asyncio.sleep(timeout)
+            # elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+            #     raise e  # don't wait to retry if deployment hits user-defined rate-limit
+
+            # elif hasattr(original_exception, "status_code") and litellm._should_retry(
+            #     status_code=original_exception.status_code
+            # ):
+            #     if hasattr(original_exception, "response") and hasattr(
+            #         original_exception.response, "headers"
+            #     ):
+            #         timeout = litellm._calculate_retry_after(
+            #             remaining_retries=num_retries,
+            #             max_retries=num_retries,
+            #             response_headers=original_exception.response.headers,
+            #             min_timeout=self.retry_after,
+            #         )
+            #     else:
+            #         timeout = litellm._calculate_retry_after(
+            #             remaining_retries=num_retries,
+            #             max_retries=num_retries,
+            #             min_timeout=self.retry_after,
+            #         )
+            #     await asyncio.sleep(timeout)
+            # else:
+            #     raise original_exception
+
+            ### RETRY
+            _timeout = self._router_should_retry(
+                e=original_exception,
+                remaining_retries=num_retries,
+                num_retries=num_retries,
+            )
+            await asyncio.sleep(_timeout)
             ## LOGGING
             if num_retries > 0:
                 kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
@@ -1505,34 +1512,12 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    if "No models available" in str(e):
-                        timeout = litellm._calculate_retry_after(
-                            remaining_retries=remaining_retries,
-                            max_retries=num_retries,
-                            min_timeout=self.retry_after,
-                        )
-                        await asyncio.sleep(timeout)
-                    elif (
-                        hasattr(e, "status_code")
-                        and hasattr(e, "response")
-                        and litellm._should_retry(status_code=e.status_code)
-                    ):
-                        if hasattr(e.response, "headers"):
-                            timeout = litellm._calculate_retry_after(
-                                remaining_retries=remaining_retries,
-                                max_retries=num_retries,
-                                response_headers=e.response.headers,
-                                min_timeout=self.retry_after,
-                            )
-                        else:
-                            timeout = litellm._calculate_retry_after(
-                                remaining_retries=remaining_retries,
-                                max_retries=num_retries,
-                                min_timeout=self.retry_after,
-                            )
-                        await asyncio.sleep(timeout)
-                    else:
-                        raise e
+                    _timeout = self._router_should_retry(
+                        e=original_exception,
+                        remaining_retries=remaining_retries,
+                        num_retries=num_retries,
+                    )
+                    await asyncio.sleep(_timeout)
         raise original_exception

     def function_with_fallbacks(self, *args, **kwargs):
@@ -1625,7 +1610,7 @@ class Router:

     def _router_should_retry(
         self, e: Exception, remaining_retries: int, num_retries: int
-    ):
+    ) -> Union[int, float]:
         """
         Calculate back-off, then retry
         """
@@ -1636,14 +1621,13 @@ class Router:
                 response_headers=e.response.headers,
                 min_timeout=self.retry_after,
             )
-            time.sleep(timeout)
         else:
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,
                 max_retries=num_retries,
                 min_timeout=self.retry_after,
             )
-            time.sleep(timeout)
+        return timeout

     def function_with_retries(self, *args, **kwargs):
         """
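The key structural change here: _router_should_retry now only computes the back-off and returns it (hence the Union[int, float] annotation), and each caller decides how to wait — time.sleep() in the sync path, await asyncio.sleep() in the async path, as the surrounding hunks show. A self-contained sketch of that pattern, using a toy stand-in for litellm._calculate_retry_after (the real back-off math may differ):

import asyncio
import time


def compute_retry_timeout(remaining_retries: int, max_retries: int, min_timeout: float = 0.0) -> float:
    # Toy stand-in for litellm._calculate_retry_after: exponential back-off,
    # never below min_timeout. Illustrative only.
    attempt = max(0, max_retries - remaining_retries)
    return max(min_timeout, 0.5 * (2 ** attempt))


def sync_caller():
    # sync path: the caller blocks on the computed timeout
    timeout = compute_retry_timeout(remaining_retries=2, max_retries=2, min_timeout=0.1)
    time.sleep(timeout)


async def async_caller():
    # async path: same helper, but the caller awaits instead of blocking
    timeout = compute_retry_timeout(remaining_retries=2, max_retries=2, min_timeout=0.1)
    await asyncio.sleep(timeout)


sync_caller()
asyncio.run(async_caller())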
@@ -1658,6 +1642,7 @@ class Router:
         context_window_fallbacks = kwargs.pop(
             "context_window_fallbacks", self.context_window_fallbacks
         )

         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = original_function(*args, **kwargs)
@@ -1677,11 +1662,12 @@ class Router:
             if num_retries > 0:
                 kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
             ### RETRY
-            self._router_should_retry(
+            _timeout = self._router_should_retry(
                 e=original_exception,
                 remaining_retries=num_retries,
                 num_retries=num_retries,
             )
+            time.sleep(_timeout)
             for current_attempt in range(num_retries):
                 verbose_router_logger.debug(
                     f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
@@ -1695,11 +1681,12 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    self._router_should_retry(
+                    _timeout = self._router_should_retry(
                         e=e,
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
                     )
+                    time.sleep(_timeout)
         raise original_exception

     ### HELPER FUNCTIONS
@@ -1733,10 +1720,11 @@ class Router:
             )  # i.e. azure
             metadata = kwargs.get("litellm_params", {}).get("metadata", None)
             _model_info = kwargs.get("litellm_params", {}).get("model_info", {})
             if isinstance(_model_info, dict):
                 deployment_id = _model_info.get("id", None)
                 self._set_cooldown_deployments(
-                    deployment_id
+                    exception_status=exception_status, deployment=deployment_id
                 )  # setting deployment_id in cooldown deployments
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
@@ -1796,9 +1784,15 @@ class Router:
             key=rpm_key, value=request_count, local_only=True
         )  # don't change existing ttl

-    def _set_cooldown_deployments(self, deployment: Optional[str] = None):
+    def _set_cooldown_deployments(
+        self, exception_status: Union[str, int], deployment: Optional[str] = None
+    ):
         """
         Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
+
+        or
+
+        the exception is not one that should be immediately retried (e.g. 401)
         """
         if deployment is None:
             return
@@ -1815,7 +1809,20 @@ class Router:
             f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
         )
         cooldown_time = self.cooldown_time or 1
-        if updated_fails > self.allowed_fails:
+
+        if isinstance(exception_status, str):
+            try:
+                exception_status = int(exception_status)
+            except Exception as e:
+                verbose_router_logger.debug(
+                    "Unable to cast exception status to int {}. Defaulting to status=500.".format(
+                        exception_status
+                    )
+                )
+                exception_status = 500
+        _should_retry = litellm._should_retry(status_code=exception_status)
+
+        if updated_fails > self.allowed_fails or _should_retry == False:
             # get the current cooldown list for that minute
             cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
             cached_value = self.cache.get_cache(key=cooldown_key)
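The net effect of the hunk above is a two-part cooldown gate: a deployment is cooled down either when it exceeds allowed_fails, or when the failure's status code is one that retrying will not fix (e.g. 401). A standalone sketch of that decision — the retryable-status set below is an assumption for illustration only; litellm's actual _should_retry may use different rules:

RETRYABLE_STATUSES = {408, 409, 429, 500, 502, 503, 504}  # assumed, illustrative only


def should_cooldown(updated_fails: int, allowed_fails: int, exception_status) -> bool:
    # Cast string statuses ("429") to int, defaulting to 500 on failure,
    # mirroring the cast added in _set_cooldown_deployments.
    try:
        status = int(exception_status)
    except (TypeError, ValueError):
        status = 500
    retryable = status in RETRYABLE_STATUSES
    # Cool down if the deployment exceeded its allowed fails, OR the error is
    # one that retrying won't fix (e.g. a 401 auth error).
    return updated_fails > allowed_fails or not retryable


print(should_cooldown(updated_fails=1, allowed_fails=3, exception_status=401))    # True: non-retryable
print(should_cooldown(updated_fails=1, allowed_fails=3, exception_status="429"))  # False: retryable, under the limit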
@@ -19,6 +19,7 @@ def setup_and_teardown():
         0, os.path.abspath("../..")
     )  # Adds the project directory to the system path
     import litellm
+    from litellm import Router

     importlib.reload(litellm)
     import asyncio
@@ -104,6 +104,42 @@ def test_router_timeout_init(timeout, ssl_verify):
     )


+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_router_retries(sync_mode):
+    """
+    - make sure retries work as expected
+    """
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad-key"},
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+            },
+        },
+    ]
+
+    router = Router(model_list=model_list, num_retries=2)
+
+    if sync_mode:
+        router.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+    else:
+        await router.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+
+
 @pytest.mark.parametrize(
     "mistral_api_base",
     [
@@ -1118,6 +1154,7 @@ def test_consistent_model_id():
     assert id1 == id2


+@pytest.mark.skip(reason="local test")
 def test_reading_keys_os_environ():
     import openai
@@ -1217,6 +1254,7 @@ def test_reading_keys_os_environ():
 # test_reading_keys_os_environ()


+@pytest.mark.skip(reason="local test")
 def test_reading_openai_keys_os_environ():
     import openai
@@ -46,6 +46,7 @@ def test_async_fallbacks(caplog):
     router = Router(
         model_list=model_list,
         fallbacks=[{"gpt-3.5-turbo": ["azure/gpt-3.5-turbo"]}],
+        num_retries=1,
     )

     user_message = "Hello, how are you?"
@@ -82,6 +83,7 @@ def test_async_fallbacks(caplog):
     # - error request, falling back notice, success notice
     expected_logs = [
         "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
+        "litellm.acompletion(model=None)\x1b[31m Exception No deployments available for selected model, passed model=gpt-3.5-turbo\x1b[0m",
         "Falling back to model_group = azure/gpt-3.5-turbo",
         "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
     ]
@@ -22,10 +22,10 @@ class MyCustomHandler(CustomLogger):
     def log_pre_api_call(self, model, messages, kwargs):
         print(f"Pre-API Call")
         print(
-            f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
+            f"previous_models: {kwargs['litellm_params']['metadata'].get('previous_models', None)}"
         )
-        self.previous_models += len(
-            kwargs["litellm_params"]["metadata"]["previous_models"]
+        self.previous_models = len(
+            kwargs["litellm_params"]["metadata"].get("previous_models", [])
         )  # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
         print(f"self.previous_models: {self.previous_models}")
@@ -127,7 +127,7 @@ def test_sync_fallbacks():
         response = router.completion(**kwargs)
         print(f"response: {response}")
         time.sleep(0.05)  # allow a delay as success_callbacks are on a separate thread
-        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        assert customHandler.previous_models == 4

         print("Passed ! Test router_fallbacks: test_sync_fallbacks()")
         router.reset()
@@ -140,7 +140,7 @@ def test_sync_fallbacks():

 @pytest.mark.asyncio
 async def test_async_fallbacks():
-    litellm.set_verbose = False
+    litellm.set_verbose = True
     model_list = [
         {  # list of model deployments
             "model_name": "azure/gpt-3.5-turbo",  # openai model name
@@ -209,12 +209,13 @@ async def test_async_fallbacks():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
     try:
+        kwargs["model"] = "azure/gpt-3.5-turbo"
         response = await router.acompletion(**kwargs)
         print(f"customHandler.previous_models: {customHandler.previous_models}")
         await asyncio.sleep(
             0.05
         )  # allow a delay as success_callbacks are on a separate thread
-        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        assert customHandler.previous_models == 4  # 1 init call, 2 retries, 1 fallback
         router.reset()
     except litellm.Timeout as e:
         pass
@@ -258,7 +259,6 @@ def test_sync_fallbacks_embeddings():
         model_list=model_list,
         fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
         set_verbose=False,
-        num_retries=0,
     )
     customHandler = MyCustomHandler()
     litellm.callbacks = [customHandler]
@@ -269,7 +269,7 @@ def test_sync_fallbacks_embeddings():
         response = router.embedding(**kwargs)
         print(f"customHandler.previous_models: {customHandler.previous_models}")
         time.sleep(0.05)  # allow a delay as success_callbacks are on a separate thread
-        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        assert customHandler.previous_models == 4  # 1 init call, 2 retries, 1 fallback
         router.reset()
     except litellm.Timeout as e:
         pass
@@ -323,7 +323,7 @@ async def test_async_fallbacks_embeddings():
         await asyncio.sleep(
             0.05
         )  # allow a delay as success_callbacks are on a separate thread
-        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        assert customHandler.previous_models == 4  # 1 init call, 2 retries, 1 fallback
         router.reset()
     except litellm.Timeout as e:
         pass
@@ -394,7 +394,7 @@ def test_dynamic_fallbacks_sync():
         },
     ]

-    router = Router(model_list=model_list, set_verbose=True, num_retries=0)
+    router = Router(model_list=model_list, set_verbose=True)
     kwargs = {}
     kwargs["model"] = "azure/gpt-3.5-turbo"
     kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
@@ -402,7 +402,7 @@ def test_dynamic_fallbacks_sync():
         response = router.completion(**kwargs)
         print(f"response: {response}")
         time.sleep(0.05)  # allow a delay as success_callbacks are on a separate thread
-        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        assert customHandler.previous_models == 4  # 1 init call, 2 retries, 1 fallback
         router.reset()
     except Exception as e:
         pytest.fail(f"An exception occurred - {e}")
@@ -488,7 +488,7 @@ async def test_dynamic_fallbacks_async():
         await asyncio.sleep(
             0.05
         )  # allow a delay as success_callbacks are on a separate thread
-        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        assert customHandler.previous_models == 4  # 1 init call, 2 retries, 1 fallback
         router.reset()
     except Exception as e:
         pytest.fail(f"An exception occurred - {e}")
@@ -573,7 +573,7 @@ async def test_async_fallbacks_streaming():
         await asyncio.sleep(
             0.05
         )  # allow a delay as success_callbacks are on a separate thread
-        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        assert customHandler.previous_models == 4  # 1 init call, 2 retries, 1 fallback
         router.reset()
     except litellm.Timeout as e:
         pass
litellm/tests/test_router_retries.py (new file, 121 additions)

@@ -0,0 +1,121 @@
+#### What this tests ####
+# This tests calling router with fallback models
+
+import sys, os, time
+import traceback, asyncio
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import litellm
+from litellm import Router
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class MyCustomHandler(CustomLogger):
+    success: bool = False
+    failure: bool = False
+    previous_models: int = 0
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        print(f"Pre-API Call")
+        print(
+            f"previous_models: {kwargs['litellm_params']['metadata'].get('previous_models', None)}"
+        )
+        self.previous_models = len(
+            kwargs["litellm_params"]["metadata"].get("previous_models", [])
+        )  # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
+        print(f"self.previous_models: {self.previous_models}")
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        print(
+            f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
+        )
+
+    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Stream")
+
+    def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Stream")
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Failure")
+
+
+"""
+Test sync + async
+
+- Authorization Errors
+- Random API Error
+"""
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("error_type", ["Authorization Error", "API Error"])
+@pytest.mark.asyncio
+async def test_router_retries_errors(sync_mode, error_type):
+    """
+    - Auth Error -> 0 retries
+    - API Error -> 2 retries
+    """
+    _api_key = (
+        "bad-key" if error_type == "Authorization Error" else os.getenv("AZURE_API_KEY")
+    )
+    print(f"_api_key: {_api_key}")
+    model_list = [
+        {
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-functioncalling",
+                "api_key": _api_key,
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+    ]
+
+    router = Router(model_list=model_list, allowed_fails=3)
+
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    user_message = "Hello, how are you?"
+    messages = [{"content": user_message, "role": "user"}]
+
+    kwargs = {
+        "model": "azure/gpt-3.5-turbo",
+        "messages": messages,
+        "mock_response": (
+            None
+            if error_type == "Authorization Error"
+            else Exception("Invalid Request")
+        ),
+    }
+
+    try:
+        if sync_mode:
+            response = router.completion(**kwargs)
+        else:
+            response = await router.acompletion(**kwargs)
+    except Exception as e:
+        pass
+
+    await asyncio.sleep(
+        0.05
+    )  # allow a delay as success_callbacks are on a separate thread
+    print(f"customHandler.previous_models: {customHandler.previous_models}")
+
+    if error_type == "Authorization Error":
+        assert customHandler.previous_models == 0  # 0 retries
+    else:
+        assert customHandler.previous_models == 2  # 2 retries
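Assuming a standard local pytest setup with the Azure environment variables referenced above exported, the new retry tests can be run with, e.g., python -m pytest litellm/tests/test_router_retries.py -k test_router_retries_errors. Per the assertions, the Authorization Error case expects zero retries while the mocked API Error case expects the two configured retries.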