Add attempted-retries and timeout values to response headers + more testing (#7926)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 14s

* feat(router.py): add retry headers to response

makes it easy to add testing to ensure model-specific retries are respected

* fix(add_retry_headers.py): clarify attempted retries vs. max retries

* test(test_fallbacks.py): add test for checking if max retries set for model is respected

* test(test_fallbacks.py): assert values for attempted retries and max retries are as expected

* fix(utils.py): return timeout in litellm proxy response headers

* test(test_fallbacks.py): add test to assert model specific timeout used on timeout error

* test: add bad model with timeout to proxy

* fix: fix linting error

* fix(router.py): fix get model list from model alias

* test: loosen test restriction - account for other events on proxy
This commit is contained in:
Krish Dholakia 2025-01-22 22:19:44 -08:00 committed by GitHub
parent bc546d82a1
commit 513b1904ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 245 additions and 31 deletions

View file

@ -669,6 +669,7 @@ def _get_wrapper_num_retries(
Get the number of retries from the kwargs and the retry policy.
Used for the wrapper functions.
"""
num_retries = kwargs.get("num_retries", None)
if num_retries is None:
num_retries = litellm.num_retries
@ -684,6 +685,21 @@ def _get_wrapper_num_retries(
return num_retries, kwargs
def _get_wrapper_timeout(
kwargs: Dict[str, Any], exception: Exception
) -> Optional[Union[float, int, httpx.Timeout]]:
"""
Get the timeout from the kwargs
Used for the wrapper functions.
"""
timeout = cast(
Optional[Union[float, int, httpx.Timeout]], kwargs.get("timeout", None)
)
return timeout
def client(original_function): # noqa: PLR0915
rules_obj = Rules()
@ -1243,9 +1259,11 @@ def client(original_function): # noqa: PLR0915
_is_litellm_router_call = "model_group" in kwargs.get(
"metadata", {}
) # check if call from litellm.router/proxy
if (
num_retries and not _is_litellm_router_call
): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
try:
litellm.num_retries = (
None # set retries to None to prevent infinite loops
@ -1266,6 +1284,7 @@ def client(original_function): # noqa: PLR0915
and context_window_fallback_dict
and model in context_window_fallback_dict
):
if len(args) > 0:
args[0] = context_window_fallback_dict[model] # type: ignore
else:
@ -1275,6 +1294,9 @@ def client(original_function): # noqa: PLR0915
setattr(
e, "num_retries", num_retries
) ## IMPORTANT: returns the deployment's num_retries to the router
timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e)
setattr(e, "timeout", timeout)
raise e
is_coroutine = inspect.iscoroutinefunction(original_function)