feat(router.py): support a 'retry_after' param that sets the minimum time to wait before retrying a failed request (default 0)

This commit is contained in:
Krrish Dholakia 2023-12-29 15:18:17 +05:30
parent 4a028d012a
commit 4882325c35

View file

@ -90,6 +90,7 @@ class Router:
allowed_fails: Optional[int] = None, allowed_fails: Optional[int] = None,
context_window_fallbacks: List = [], context_window_fallbacks: List = [],
model_group_alias: Optional[dict] = {}, model_group_alias: Optional[dict] = {},
retry_after: int = 0, # min time to wait before retrying a failed request
routing_strategy: Literal[ routing_strategy: Literal[
"simple-shuffle", "simple-shuffle",
"least-busy", "least-busy",
@ -115,6 +116,7 @@ class Router:
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
self.num_retries = num_retries or litellm.num_retries or 0 self.num_retries = num_retries or litellm.num_retries or 0
self.timeout = timeout or litellm.request_timeout self.timeout = timeout or litellm.request_timeout
self.retry_after = retry_after
self.routing_strategy = routing_strategy self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks self.fallbacks = fallbacks or litellm.fallbacks
self.context_window_fallbacks = ( self.context_window_fallbacks = (
@ -776,7 +778,9 @@ class Router:
#### check if it should retry + back-off if required #### check if it should retry + back-off if required
if "No models available" in str(e): if "No models available" in str(e):
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=num_retries, max_retries=num_retries remaining_retries=num_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
) )
await asyncio.sleep(timeout) await asyncio.sleep(timeout)
elif ( elif (
@ -789,10 +793,13 @@ class Router:
remaining_retries=num_retries, remaining_retries=num_retries,
max_retries=num_retries, max_retries=num_retries,
response_headers=original_exception.response.headers, response_headers=original_exception.response.headers,
min_timeout=self.retry_after,
) )
else: else:
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=num_retries, max_retries=num_retries remaining_retries=num_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
) )
await asyncio.sleep(timeout) await asyncio.sleep(timeout)
else: else:
@ -823,7 +830,7 @@ class Router:
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, max_retries=num_retries,
min_timeout=1, min_timeout=self.retry_after,
) )
await asyncio.sleep(timeout) await asyncio.sleep(timeout)
elif ( elif (
@ -836,11 +843,13 @@ class Router:
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, max_retries=num_retries,
response_headers=e.response.headers, response_headers=e.response.headers,
min_timeout=self.retry_after,
) )
else: else:
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, max_retries=num_retries,
min_timeout=self.retry_after,
) )
await asyncio.sleep(timeout) await asyncio.sleep(timeout)
else: else:
@ -972,7 +981,7 @@ class Router:
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, max_retries=num_retries,
min_timeout=1, min_timeout=self.retry_after,
) )
time.sleep(timeout) time.sleep(timeout)
elif ( elif (
@ -985,11 +994,13 @@ class Router:
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, max_retries=num_retries,
response_headers=e.response.headers, response_headers=e.response.headers,
min_timeout=self.retry_after,
) )
else: else:
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, max_retries=num_retries,
min_timeout=self.retry_after,
) )
time.sleep(timeout) time.sleep(timeout)
else: else: