Add attempted-retries and timeout values to response headers + more testing (#7926)

* feat(router.py): add retry headers to response

makes it easy to add testing to ensure model-specific retries are respected

* fix(add_retry_headers.py): clarify attempted retries vs. max retries

* test(test_fallbacks.py): add test for checking if max retries set for model is respected

* test(test_fallbacks.py): assert values for attempted retries and max retries are as expected

* fix(utils.py): return timeout in litellm proxy response headers

* test(test_fallbacks.py): add test to assert the model-specific timeout is used on a timeout error

* test: add bad model with timeout to proxy

* fix: fix linting error

* fix(router.py): fix get model list from model alias

* test: loosen test restriction - account for other events on proxy
This commit is contained in:
Krish Dholakia 2025-01-22 22:19:44 -08:00 committed by GitHub
parent a1a81fc0f3
commit b286bab075
9 changed files with 245 additions and 31 deletions

View file

@ -787,9 +787,10 @@ def get_custom_headers(
hidden_params: Optional[dict] = None,
fastest_response_batch_completion: Optional[bool] = None,
request_data: Optional[dict] = {},
timeout: Optional[Union[float, int, httpx.Timeout]] = None,
**kwargs,
) -> dict:
exclude_values = {"", None}
exclude_values = {"", None, "None"}
hidden_params = hidden_params or {}
headers = {
"x-litellm-call-id": call_id,
@ -812,6 +813,7 @@ def get_custom_headers(
if fastest_response_batch_completion is not None
else None
),
"x-litellm-timeout": str(timeout) if timeout is not None else None,
**{k: str(v) for k, v in kwargs.items()},
}
if request_data:
@ -3638,14 +3640,28 @@ async def chat_completion( # noqa: PLR0915
litellm_debug_info,
)
timeout = getattr(
e, "timeout", None
) # returns the timeout set by the wrapper. Used for testing if model-specific timeouts are set correctly
custom_headers = get_custom_headers(
user_api_key_dict=user_api_key_dict,
version=version,
response_cost=0,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
timeout=timeout,
)
headers = getattr(e, "headers", {}) or {}
headers.update(custom_headers)
if isinstance(e, HTTPException):
# print("e.headers={}".format(e.headers))
raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
headers=getattr(e, "headers", {}),
headers=headers,
)
error_msg = f"{str(e)}"
raise ProxyException(
@ -3653,7 +3669,7 @@ async def chat_completion( # noqa: PLR0915
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
headers=getattr(e, "headers", {}),
headers=headers,
)