forked from phoenix/litellm-mirror
fix(router.py): unify retry timeout logic across sync + async function_with_retries
This commit is contained in:
parent
285a3733a9
commit
87ff26ff27
2 changed files with 119 additions and 65 deletions
|
@ -1418,6 +1418,13 @@ class Router:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise original_exception
|
raise original_exception
|
||||||
|
|
||||||
|
async def _async_router_should_retry(
    self, e: Exception, remaining_retries: int, num_retries: int
):
    """
    Placeholder for an async back-off helper.

    The coroutine contains no logic yet: awaiting it ignores `e`,
    `remaining_retries` and `num_retries` and resolves to None.

    NOTE(review): the async retry paths in this file call
    `self._router_should_retry(...)` and then `await asyncio.sleep(...)`
    themselves - confirm whether this stub is still needed.
    """
    return None
|
||||||
|
|
||||||
async def async_function_with_retries(self, *args, **kwargs):
|
async def async_function_with_retries(self, *args, **kwargs):
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
|
f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
|
||||||
|
@ -1450,40 +1457,47 @@ class Router:
|
||||||
raise original_exception
|
raise original_exception
|
||||||
### RETRY
|
### RETRY
|
||||||
#### check if it should retry + back-off if required
|
#### check if it should retry + back-off if required
|
||||||
if "No models available" in str(
|
# if "No models available" in str(
|
||||||
e
|
# e
|
||||||
) or RouterErrors.no_deployments_available.value in str(e):
|
# ) or RouterErrors.no_deployments_available.value in str(e):
|
||||||
timeout = litellm._calculate_retry_after(
|
# timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=num_retries,
|
# remaining_retries=num_retries,
|
||||||
max_retries=num_retries,
|
# max_retries=num_retries,
|
||||||
min_timeout=self.retry_after,
|
# min_timeout=self.retry_after,
|
||||||
)
|
# )
|
||||||
await asyncio.sleep(timeout)
|
# await asyncio.sleep(timeout)
|
||||||
elif RouterErrors.user_defined_ratelimit_error.value in str(e):
|
# elif RouterErrors.user_defined_ratelimit_error.value in str(e):
|
||||||
raise e # don't wait to retry if deployment hits user-defined rate-limit
|
# raise e # don't wait to retry if deployment hits user-defined rate-limit
|
||||||
|
|
||||||
elif hasattr(original_exception, "status_code") and litellm._should_retry(
|
# elif hasattr(original_exception, "status_code") and litellm._should_retry(
|
||||||
status_code=original_exception.status_code
|
# status_code=original_exception.status_code
|
||||||
):
|
# ):
|
||||||
if hasattr(original_exception, "response") and hasattr(
|
# if hasattr(original_exception, "response") and hasattr(
|
||||||
original_exception.response, "headers"
|
# original_exception.response, "headers"
|
||||||
):
|
# ):
|
||||||
timeout = litellm._calculate_retry_after(
|
# timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=num_retries,
|
# remaining_retries=num_retries,
|
||||||
max_retries=num_retries,
|
# max_retries=num_retries,
|
||||||
response_headers=original_exception.response.headers,
|
# response_headers=original_exception.response.headers,
|
||||||
min_timeout=self.retry_after,
|
# min_timeout=self.retry_after,
|
||||||
)
|
# )
|
||||||
else:
|
# else:
|
||||||
timeout = litellm._calculate_retry_after(
|
# timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=num_retries,
|
# remaining_retries=num_retries,
|
||||||
max_retries=num_retries,
|
# max_retries=num_retries,
|
||||||
min_timeout=self.retry_after,
|
# min_timeout=self.retry_after,
|
||||||
)
|
# )
|
||||||
await asyncio.sleep(timeout)
|
# await asyncio.sleep(timeout)
|
||||||
else:
|
# else:
|
||||||
raise original_exception
|
# raise original_exception
|
||||||
|
|
||||||
|
### RETRY
|
||||||
|
_timeout = self._router_should_retry(
|
||||||
|
e=original_exception,
|
||||||
|
remaining_retries=num_retries,
|
||||||
|
num_retries=num_retries,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_timeout)
|
||||||
## LOGGING
|
## LOGGING
|
||||||
if num_retries > 0:
|
if num_retries > 0:
|
||||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||||
|
@ -1505,34 +1519,37 @@ class Router:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||||
remaining_retries = num_retries - current_attempt
|
remaining_retries = num_retries - current_attempt
|
||||||
if "No models available" in str(e):
|
# if "No models available" in str(e):
|
||||||
timeout = litellm._calculate_retry_after(
|
# timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=remaining_retries,
|
# remaining_retries=remaining_retries,
|
||||||
max_retries=num_retries,
|
# max_retries=num_retries,
|
||||||
min_timeout=self.retry_after,
|
# min_timeout=self.retry_after,
|
||||||
)
|
# )
|
||||||
await asyncio.sleep(timeout)
|
# await asyncio.sleep(timeout)
|
||||||
elif (
|
# elif (
|
||||||
hasattr(e, "status_code")
|
# hasattr(e, "status_code")
|
||||||
and hasattr(e, "response")
|
# and hasattr(e, "response")
|
||||||
and litellm._should_retry(status_code=e.status_code)
|
# and litellm._should_retry(status_code=e.status_code)
|
||||||
):
|
# ):
|
||||||
if hasattr(e.response, "headers"):
|
# if hasattr(e.response, "headers"):
|
||||||
timeout = litellm._calculate_retry_after(
|
# timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=remaining_retries,
|
# remaining_retries=remaining_retries,
|
||||||
max_retries=num_retries,
|
# max_retries=num_retries,
|
||||||
response_headers=e.response.headers,
|
# response_headers=e.response.headers,
|
||||||
min_timeout=self.retry_after,
|
# min_timeout=self.retry_after,
|
||||||
)
|
# )
|
||||||
else:
|
# else:
|
||||||
timeout = litellm._calculate_retry_after(
|
# timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=remaining_retries,
|
# remaining_retries=remaining_retries,
|
||||||
max_retries=num_retries,
|
# max_retries=num_retries,
|
||||||
min_timeout=self.retry_after,
|
# min_timeout=self.retry_after,
|
||||||
)
|
# )
|
||||||
await asyncio.sleep(timeout)
|
_timeout = self._router_should_retry(
|
||||||
else:
|
e=original_exception,
|
||||||
raise e
|
remaining_retries=remaining_retries,
|
||||||
|
num_retries=num_retries,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_timeout)
|
||||||
raise original_exception
|
raise original_exception
|
||||||
|
|
||||||
def function_with_fallbacks(self, *args, **kwargs):
|
def function_with_fallbacks(self, *args, **kwargs):
|
||||||
|
@ -1625,7 +1642,7 @@ class Router:
|
||||||
|
|
||||||
def _router_should_retry(
|
def _router_should_retry(
|
||||||
self, e: Exception, remaining_retries: int, num_retries: int
|
self, e: Exception, remaining_retries: int, num_retries: int
|
||||||
):
|
) -> int | float:
|
||||||
"""
|
"""
|
||||||
Calculate back-off, then retry
|
Calculate back-off, then retry
|
||||||
"""
|
"""
|
||||||
|
@ -1636,14 +1653,13 @@ class Router:
|
||||||
response_headers=e.response.headers,
|
response_headers=e.response.headers,
|
||||||
min_timeout=self.retry_after,
|
min_timeout=self.retry_after,
|
||||||
)
|
)
|
||||||
time.sleep(timeout)
|
|
||||||
else:
|
else:
|
||||||
timeout = litellm._calculate_retry_after(
|
timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=remaining_retries,
|
remaining_retries=remaining_retries,
|
||||||
max_retries=num_retries,
|
max_retries=num_retries,
|
||||||
min_timeout=self.retry_after,
|
min_timeout=self.retry_after,
|
||||||
)
|
)
|
||||||
time.sleep(timeout)
|
return timeout
|
||||||
|
|
||||||
def function_with_retries(self, *args, **kwargs):
|
def function_with_retries(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -1677,11 +1693,12 @@ class Router:
|
||||||
if num_retries > 0:
|
if num_retries > 0:
|
||||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||||
### RETRY
|
### RETRY
|
||||||
self._router_should_retry(
|
_timeout = self._router_should_retry(
|
||||||
e=original_exception,
|
e=original_exception,
|
||||||
remaining_retries=num_retries,
|
remaining_retries=num_retries,
|
||||||
num_retries=num_retries,
|
num_retries=num_retries,
|
||||||
)
|
)
|
||||||
|
time.sleep(_timeout)
|
||||||
for current_attempt in range(num_retries):
|
for current_attempt in range(num_retries):
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
|
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
|
||||||
|
@ -1695,11 +1712,12 @@ class Router:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||||
remaining_retries = num_retries - current_attempt
|
remaining_retries = num_retries - current_attempt
|
||||||
self._router_should_retry(
|
_timeout = self._router_should_retry(
|
||||||
e=e,
|
e=e,
|
||||||
remaining_retries=remaining_retries,
|
remaining_retries=remaining_retries,
|
||||||
num_retries=num_retries,
|
num_retries=num_retries,
|
||||||
)
|
)
|
||||||
|
time.sleep(_timeout)
|
||||||
raise original_exception
|
raise original_exception
|
||||||
|
|
||||||
### HELPER FUNCTIONS
|
### HELPER FUNCTIONS
|
||||||
|
|
|
@ -104,6 +104,42 @@ def test_router_timeout_init(timeout, ssl_verify):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio
async def test_router_retries(sync_mode):
    """
    Smoke-test a completion call through a Router built with num_retries=2.

    Two deployments share the "gpt-3.5-turbo" model name: one OpenAI entry
    with the api key "bad-key" and one Azure entry configured from the
    AZURE_API_KEY / AZURE_API_BASE / AZURE_API_VERSION environment variables.

    `sync_mode` selects `router.completion` (True) or the awaited
    `router.acompletion` (False). There are no explicit assertions; the test
    passes as long as the call does not raise.
    """
    # presumably this deployment fails and so exercises the retry path - verify
    bad_key_deployment = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad-key"},
    }
    azure_deployment = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
        },
    }

    router = Router(
        model_list=[bad_key_deployment, azure_deployment],
        num_retries=2,
    )

    # Same request for both code paths; only the entry point differs.
    request = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    }

    if not sync_mode:
        await router.acompletion(**request)
        return

    router.completion(**request)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"mistral_api_base",
|
"mistral_api_base",
|
||||||
[
|
[
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue