Add attempted-retries and timeout values to response headers + more testing (#7926)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 14s

* feat(router.py): add retry headers to response

makes it easy to add testing to ensure model-specific retries are respected

* fix(add_retry_headers.py): clarify attempted retries vs. max retries

* test(test_fallbacks.py): add test for checking if max retries set for model is respected

* test(test_fallbacks.py): assert values for attempted retries and max retries are as expected

* fix(utils.py): return timeout in litellm proxy response headers

* test(test_fallbacks.py): add test to assert model specific timeout used on timeout error

* test: add bad model with timeout to proxy

* fix: fix linting error

* fix(router.py): fix get model list from model alias

* test: loosen test restriction - account for other events on proxy
This commit is contained in:
Krish Dholakia 2025-01-22 22:19:44 -08:00 committed by GitHub
parent bc546d82a1
commit 513b1904ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 245 additions and 31 deletions

View file

@ -57,6 +57,7 @@ from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_strategy.simple_shuffle import simple_shuffle
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
from litellm.router_utils.add_retry_headers import add_retry_headers_to_response
from litellm.router_utils.batch_utils import (
_get_router_metadata_variable_name,
replace_model_in_jsonl,
@ -3090,12 +3091,15 @@ class Router:
)
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = await self.make_call(original_function, *args, **kwargs)
response = add_retry_headers_to_response(
response=response, attempted_retries=0, max_retries=None
)
return response
except Exception as e:
current_attempt = None
original_exception = e
deployment_num_retries = getattr(e, "num_retries", None)
if deployment_num_retries is not None and isinstance(
deployment_num_retries, int
):
@ -3156,6 +3160,12 @@ class Router:
response
): # async errors are often returned as coroutines
response = await response
response = add_retry_headers_to_response(
response=response,
attempted_retries=current_attempt + 1,
max_retries=num_retries,
)
return response
except Exception as e:
@ -3214,6 +3224,15 @@ class Router:
mock_testing_rate_limit_error: Optional[bool] = kwargs.pop(
"mock_testing_rate_limit_error", None
)
available_models = self.get_model_list(model_name=model_group)
num_retries: Optional[int] = None
if available_models is not None and len(available_models) == 1:
num_retries = cast(
Optional[int], available_models[0]["litellm_params"].get("num_retries")
)
if (
mock_testing_rate_limit_error is not None
and mock_testing_rate_limit_error is True
@ -3225,6 +3244,7 @@ class Router:
model=model_group,
llm_provider="",
message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
num_retries=num_retries,
)
def should_retry_this_error(
@ -4776,6 +4796,37 @@ class Router:
model_names.append(m["model_name"])
return model_names
def get_model_list_from_model_alias(
self, model_name: Optional[str] = None
) -> List[DeploymentTypedDict]:
"""
Helper function to get model list from model alias.
Used by `.get_model_list` to get model list from model alias.
"""
returned_models: List[DeploymentTypedDict] = []
for model_alias, model_value in self.model_group_alias.items():
if model_name is not None and model_alias != model_name:
continue
if isinstance(model_value, str):
_router_model_name: str = model_value
elif isinstance(model_value, dict):
_model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
if _model_value["hidden"] is True:
continue
else:
_router_model_name = _model_value["model"]
else:
continue
returned_models.extend(
self._get_all_deployments(
model_name=_router_model_name, model_alias=model_alias
)
)
return returned_models
def get_model_list(
self, model_name: Optional[str] = None
) -> Optional[List[DeploymentTypedDict]]:
@ -4789,24 +4840,9 @@ class Router:
returned_models.extend(self._get_all_deployments(model_name=model_name))
if hasattr(self, "model_group_alias"):
for model_alias, model_value in self.model_group_alias.items():
if isinstance(model_value, str):
_router_model_name: str = model_value
elif isinstance(model_value, dict):
_model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
if _model_value["hidden"] is True:
continue
else:
_router_model_name = _model_value["model"]
else:
continue
returned_models.extend(
self._get_all_deployments(
model_name=_router_model_name, model_alias=model_alias
)
)
returned_models.extend(
self.get_model_list_from_model_alias(model_name=model_name)
)
if len(returned_models) == 0: # check if wildcard route
potential_wildcard_models = self.pattern_router.route(model_name)