fix(router.py): unify retry timeout logic across sync + async function_with_retries

This commit is contained in:
Krrish Dholakia 2024-04-30 15:23:19 -07:00
parent 285a3733a9
commit 87ff26ff27
2 changed files with 119 additions and 65 deletions

View file

@ -1418,6 +1418,13 @@ class Router:
traceback.print_exc()
raise original_exception
async def _async_router_should_retry(
self, e: Exception, remaining_retries: int, num_retries: int
):
"""
Calculate back-off, then retry
"""
async def async_function_with_retries(self, *args, **kwargs):
verbose_router_logger.debug(
f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
@ -1450,40 +1457,47 @@ class Router:
raise original_exception
### RETRY
#### check if it should retry + back-off if required
if "No models available" in str( # if "No models available" in str(
e # e
) or RouterErrors.no_deployments_available.value in str(e): # ) or RouterErrors.no_deployments_available.value in str(e):
timeout = litellm._calculate_retry_after( # timeout = litellm._calculate_retry_after(
remaining_retries=num_retries, # remaining_retries=num_retries,
max_retries=num_retries, # max_retries=num_retries,
min_timeout=self.retry_after, # min_timeout=self.retry_after,
) # )
await asyncio.sleep(timeout) # await asyncio.sleep(timeout)
elif RouterErrors.user_defined_ratelimit_error.value in str(e): # elif RouterErrors.user_defined_ratelimit_error.value in str(e):
raise e # don't wait to retry if deployment hits user-defined rate-limit # raise e # don't wait to retry if deployment hits user-defined rate-limit
elif hasattr(original_exception, "status_code") and litellm._should_retry( # elif hasattr(original_exception, "status_code") and litellm._should_retry(
status_code=original_exception.status_code # status_code=original_exception.status_code
): # ):
if hasattr(original_exception, "response") and hasattr( # if hasattr(original_exception, "response") and hasattr(
original_exception.response, "headers" # original_exception.response, "headers"
): # ):
timeout = litellm._calculate_retry_after( # timeout = litellm._calculate_retry_after(
remaining_retries=num_retries, # remaining_retries=num_retries,
max_retries=num_retries, # max_retries=num_retries,
response_headers=original_exception.response.headers, # response_headers=original_exception.response.headers,
min_timeout=self.retry_after, # min_timeout=self.retry_after,
) # )
else: # else:
timeout = litellm._calculate_retry_after( # timeout = litellm._calculate_retry_after(
remaining_retries=num_retries, # remaining_retries=num_retries,
max_retries=num_retries, # max_retries=num_retries,
min_timeout=self.retry_after, # min_timeout=self.retry_after,
) # )
await asyncio.sleep(timeout) # await asyncio.sleep(timeout)
else: # else:
raise original_exception # raise original_exception
### RETRY
_timeout = self._router_should_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
)
await asyncio.sleep(_timeout)
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
@ -1505,34 +1519,37 @@ class Router:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
if "No models available" in str(e): # if "No models available" in str(e):
timeout = litellm._calculate_retry_after( # timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, # remaining_retries=remaining_retries,
max_retries=num_retries, # max_retries=num_retries,
min_timeout=self.retry_after, # min_timeout=self.retry_after,
) # )
await asyncio.sleep(timeout) # await asyncio.sleep(timeout)
elif ( # elif (
hasattr(e, "status_code") # hasattr(e, "status_code")
and hasattr(e, "response") # and hasattr(e, "response")
and litellm._should_retry(status_code=e.status_code) # and litellm._should_retry(status_code=e.status_code)
): # ):
if hasattr(e.response, "headers"): # if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after( # timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, # remaining_retries=remaining_retries,
max_retries=num_retries, # max_retries=num_retries,
response_headers=e.response.headers, # response_headers=e.response.headers,
min_timeout=self.retry_after, # min_timeout=self.retry_after,
) # )
else: # else:
timeout = litellm._calculate_retry_after( # timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, # remaining_retries=remaining_retries,
max_retries=num_retries, # max_retries=num_retries,
min_timeout=self.retry_after, # min_timeout=self.retry_after,
) # )
await asyncio.sleep(timeout) _timeout = self._router_should_retry(
else: e=original_exception,
raise e remaining_retries=remaining_retries,
num_retries=num_retries,
)
await asyncio.sleep(_timeout)
raise original_exception
def function_with_fallbacks(self, *args, **kwargs):
@ -1625,7 +1642,7 @@ class Router:
def _router_should_retry(
self, e: Exception, remaining_retries: int, num_retries: int
): ) -> int | float:
""" """
Calculate back-off, then retry Calculate back-off, then retry
""" """
@ -1636,14 +1653,13 @@ class Router:
response_headers=e.response.headers,
min_timeout=self.retry_after,
)
time.sleep(timeout)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
time.sleep(timeout) return timeout
def function_with_retries(self, *args, **kwargs):
"""
@ -1677,11 +1693,12 @@ class Router:
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
### RETRY
self._router_should_retry( _timeout = self._router_should_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
)
time.sleep(_timeout)
for current_attempt in range(num_retries):
verbose_router_logger.debug(
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
@ -1695,11 +1712,12 @@ class Router:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
self._router_should_retry( _timeout = self._router_should_retry(
e=e,
remaining_retries=remaining_retries,
num_retries=num_retries,
)
time.sleep(_timeout)
raise original_exception
### HELPER FUNCTIONS

View file

@ -104,6 +104,42 @@ def test_router_timeout_init(timeout, ssl_verify):
)
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio
async def test_router_retries(sync_mode):
"""
- make sure retries work as expected
"""
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad-key"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION"),
},
},
]
router = Router(model_list=model_list, num_retries=2)
if sync_mode:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
else:
await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
@pytest.mark.parametrize(
"mistral_api_base",
[