Add attempted-retries and timeout values to response headers + more testing (#7926)

* feat(router.py): add retry headers to response

makes it easy to add testing to ensure model-specific retries are respected

* fix(add_retry_headers.py): clarify attempted retries vs. max retries

* test(test_fallbacks.py): add test for checking if max retries set for model is respected

* test(test_fallbacks.py): assert values for attempted retries and max retries are as expected

* fix(utils.py): return timeout in litellm proxy response headers

* test(test_fallbacks.py): add test to assert model specific timeout used on timeout error

* test: add bad model with timeout to proxy

* fix: fix linting error

* fix(router.py): fix get model list from model alias

* test: loosen test restriction - account for other events on proxy
Krish Dholakia 2025-01-22 22:19:44 -08:00 committed by GitHub
parent a1a81fc0f3
commit b286bab075
9 changed files with 245 additions and 31 deletions
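
For context, a hedged client-side sketch of the new headers (assumptions, not part of this commit: a local proxy at http://0.0.0.0:4000 with master key sk-1234, the fake-openai-endpoint-4 model group added in the test config below, and the mock_testing_rate_limit_error flag used in the e2e test):

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "fake-openai-endpoint-4",
        "messages": [{"role": "user", "content": "hi"}],
        "mock_testing_rate_limit_error": True,  # force a retryable error on the first attempt, as in the e2e test
    },
    timeout=60,
)

# New retry metadata surfaced as response headers (values are strings):
print(resp.headers.get("x-litellm-attempted-retries"))
print(resp.headers.get("x-litellm-max-retries"))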

View file

@@ -1,10 +1,11 @@
model_list:
- model_name: openai-gpt-4o
- model_name: gpt-3.5-turbo-end-user-test
litellm_params:
model: openai/my-fake-openai-endpoint
api_key: sk-1234
api_base: https://exampleopenaiendpoint-production.up.railway.app
- model_name: openai-o1
model: gpt-3.5-turbo
region_name: "eu"
model_info:
id: "1"
- model_name: gpt-3.5-turbo-end-user-test
litellm_params:
model: openai/random_sleep
api_base: http://0.0.0.0:8090

View file

@@ -787,9 +787,10 @@ def get_custom_headers(
hidden_params: Optional[dict] = None,
fastest_response_batch_completion: Optional[bool] = None,
request_data: Optional[dict] = {},
timeout: Optional[Union[float, int, httpx.Timeout]] = None,
**kwargs,
) -> dict:
exclude_values = {"", None}
exclude_values = {"", None, "None"}
hidden_params = hidden_params or {}
headers = {
"x-litellm-call-id": call_id,
@@ -812,6 +813,7 @@
if fastest_response_batch_completion is not None
else None
),
"x-litellm-timeout": str(timeout) if timeout is not None else None,
**{k: str(v) for k, v in kwargs.items()},
}
if request_data:
@@ -3638,14 +3640,28 @@ async def chat_completion( # noqa: PLR0915
litellm_debug_info,
)
timeout = getattr(
e, "timeout", None
) # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
custom_headers = get_custom_headers(
user_api_key_dict=user_api_key_dict,
version=version,
response_cost=0,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
timeout=timeout,
)
headers = getattr(e, "headers", {}) or {}
headers.update(custom_headers)
if isinstance(e, HTTPException):
# print("e.headers={}".format(e.headers))
raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
headers=getattr(e, "headers", {}),
headers=headers,
)
error_msg = f"{str(e)}"
raise ProxyException(
@@ -3653,7 +3669,7 @@ async def chat_completion( # noqa: PLR0915
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
headers=getattr(e, "headers", {}),
headers=headers,
)
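
The error path above is easier to follow outside diff form. A hedged standalone sketch with hypothetical names (not part of this diff): the timeout recorded on the exception is rendered into the custom headers, which are then layered over whatever headers the exception already carried.

def merge_error_headers(exc: Exception, custom_headers: dict) -> dict:
    # Start from any headers the exception already carried (may be missing or None).
    headers = getattr(exc, "headers", {}) or {}
    # LiteLLM's own headers (e.g. x-litellm-timeout) win on key collisions.
    headers.update(custom_headers)
    return headers


class FakeTimeout(Exception):
    def __init__(self):
        super().__init__("request timed out")
        self.headers = {"retry-after": "1"}
        self.timeout = 1.0


e = FakeTimeout()
timeout = getattr(e, "timeout", None)
custom = {"x-litellm-timeout": str(timeout)} if timeout is not None else {}
print(merge_error_headers(e, custom))
# {'retry-after': '1', 'x-litellm-timeout': '1.0'}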

View file

@@ -57,6 +57,7 @@ from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_strategy.simple_shuffle import simple_shuffle
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
from litellm.router_utils.add_retry_headers import add_retry_headers_to_response
from litellm.router_utils.batch_utils import (
_get_router_metadata_variable_name,
replace_model_in_jsonl,
@@ -3090,12 +3091,15 @@ class Router:
)
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = await self.make_call(original_function, *args, **kwargs)
response = add_retry_headers_to_response(
response=response, attempted_retries=0, max_retries=None
)
return response
except Exception as e:
current_attempt = None
original_exception = e
deployment_num_retries = getattr(e, "num_retries", None)
if deployment_num_retries is not None and isinstance(
deployment_num_retries, int
):
@@ -3156,6 +3160,12 @@
response
): # async errors are often returned as coroutines
response = await response
response = add_retry_headers_to_response(
response=response,
attempted_retries=current_attempt + 1,
max_retries=num_retries,
)
return response
except Exception as e:
@@ -3214,6 +3224,15 @@
mock_testing_rate_limit_error: Optional[bool] = kwargs.pop(
"mock_testing_rate_limit_error", None
)
available_models = self.get_model_list(model_name=model_group)
num_retries: Optional[int] = None
if available_models is not None and len(available_models) == 1:
num_retries = cast(
Optional[int], available_models[0]["litellm_params"].get("num_retries")
)
if (
mock_testing_rate_limit_error is not None
and mock_testing_rate_limit_error is True
@@ -3225,6 +3244,7 @@
model=model_group,
llm_provider="",
message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
num_retries=num_retries,
)
def should_retry_this_error(
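
A hedged SDK-side sketch of what this plumbing enables (the deployment and the mock_testing_rate_limit_error kwarg mirror this commit's tests; the exact call sequence is an assumption, not a verbatim test): a deployment-level num_retries in litellm_params bounds the router's retries for that model group, and the retry metadata ends up in the response's hidden params.

import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint-4",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                "num_retries": 2,  # deployment-specific retry budget
            },
        }
    ]
)

async def main():
    response = await router.acompletion(
        model="fake-openai-endpoint-4",
        messages=[{"role": "user", "content": "hi"}],
        mock_testing_rate_limit_error=True,  # triggers the mocked rate limit error shown above
    )
    # Retry metadata (attempted retries, max retries) lands in additional_headers.
    print(response._hidden_params.get("additional_headers"))

asyncio.run(main())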
@@ -4776,6 +4796,37 @@ class Router:
model_names.append(m["model_name"])
return model_names
def get_model_list_from_model_alias(
self, model_name: Optional[str] = None
) -> List[DeploymentTypedDict]:
"""
Helper function to get model list from model alias.
Used by `.get_model_list` to get model list from model alias.
"""
returned_models: List[DeploymentTypedDict] = []
for model_alias, model_value in self.model_group_alias.items():
if model_name is not None and model_alias != model_name:
continue
if isinstance(model_value, str):
_router_model_name: str = model_value
elif isinstance(model_value, dict):
_model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
if _model_value["hidden"] is True:
continue
else:
_router_model_name = _model_value["model"]
else:
continue
returned_models.extend(
self._get_all_deployments(
model_name=_router_model_name, model_alias=model_alias
)
)
return returned_models
def get_model_list(
self, model_name: Optional[str] = None
) -> Optional[List[DeploymentTypedDict]]:
@@ -4789,24 +4840,9 @@
returned_models.extend(self._get_all_deployments(model_name=model_name))
if hasattr(self, "model_group_alias"):
for model_alias, model_value in self.model_group_alias.items():
if isinstance(model_value, str):
_router_model_name: str = model_value
elif isinstance(model_value, dict):
_model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
if _model_value["hidden"] is True:
continue
else:
_router_model_name = _model_value["model"]
else:
continue
returned_models.extend(
self._get_all_deployments(
model_name=_router_model_name, model_alias=model_alias
)
)
returned_models.extend(
self.get_model_list_from_model_alias(model_name=model_name)
)
if len(returned_models) == 0: # check if wildcard route
potential_wildcard_models = self.pattern_router.route(model_name)
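
A hedged sketch of how the new get_model_list_from_model_alias helper behaves, constructed from the logic above (the alias names are made up; dict-alias support is assumed from the RouterModelGroupAliasItem handling): string aliases resolve to the aliased group's deployments, while dict aliases marked hidden are skipped.

from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    ],
    model_group_alias={
        "gpt-4o-alias": "gpt-3.5-turbo",                               # plain string alias
        "internal-alias": {"model": "gpt-3.5-turbo", "hidden": True},  # hidden alias, skipped
    },
)

print(len(router.get_model_list_from_model_alias(model_name="gpt-4o-alias")))    # 1 deployment
print(len(router.get_model_list_from_model_alias(model_name="internal-alias")))  # 0 (hidden)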

View file

@@ -0,0 +1,40 @@
from typing import Any, Optional, Union
from pydantic import BaseModel
from litellm.types.utils import HiddenParams
def add_retry_headers_to_response(
response: Any,
attempted_retries: int,
max_retries: Optional[int] = None,
) -> Any:
"""
Add retry headers to the request
"""
if response is None or not isinstance(response, BaseModel):
return response
retry_headers = {
"x-litellm-attempted-retries": attempted_retries,
}
if max_retries is not None:
retry_headers["x-litellm-max-retries"] = max_retries
hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
response, "_hidden_params", {}
)
if hidden_params is None:
hidden_params = {}
elif isinstance(hidden_params, HiddenParams):
hidden_params = hidden_params.model_dump()
hidden_params.setdefault("additional_headers", {})
hidden_params["additional_headers"].update(retry_headers)
setattr(response, "_hidden_params", hidden_params)
return response
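
A short usage sketch for the new helper (assuming the litellm.router_utils.add_retry_headers import path shown in the router diff above, and that ModelResponse initializes _hidden_params to a dict); the proxy later renders additional_headers as x-litellm-* response headers.

from litellm import ModelResponse
from litellm.router_utils.add_retry_headers import add_retry_headers_to_response

response = add_retry_headers_to_response(
    response=ModelResponse(), attempted_retries=1, max_retries=50
)
print(response._hidden_params["additional_headers"])
# {'x-litellm-attempted-retries': 1, 'x-litellm-max-retries': 50}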

View file

@@ -352,6 +352,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
output_cost_per_token: Optional[float]
input_cost_per_second: Optional[float]
output_cost_per_second: Optional[float]
num_retries: Optional[int]
## MOCK RESPONSES ##
mock_response: Optional[Union[str, ModelResponse, Exception]]

View file

@@ -669,6 +669,7 @@ def _get_wrapper_num_retries(
Get the number of retries from the kwargs and the retry policy.
Used for the wrapper functions.
"""
num_retries = kwargs.get("num_retries", None)
if num_retries is None:
num_retries = litellm.num_retries
@@ -684,6 +685,21 @@
return num_retries, kwargs
def _get_wrapper_timeout(
kwargs: Dict[str, Any], exception: Exception
) -> Optional[Union[float, int, httpx.Timeout]]:
"""
Get the timeout from the kwargs
Used for the wrapper functions.
"""
timeout = cast(
Optional[Union[float, int, httpx.Timeout]], kwargs.get("timeout", None)
)
return timeout
def client(original_function): # noqa: PLR0915
rules_obj = Rules()
@@ -1243,9 +1259,11 @@ def client(original_function): # noqa: PLR0915
_is_litellm_router_call = "model_group" in kwargs.get(
"metadata", {}
) # check if call from litellm.router/proxy
if (
num_retries and not _is_litellm_router_call
): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
try:
litellm.num_retries = (
None # set retries to None to prevent infinite loops
@@ -1266,6 +1284,7 @@ def client(original_function): # noqa: PLR0915
and context_window_fallback_dict
and model in context_window_fallback_dict
):
if len(args) > 0:
args[0] = context_window_fallback_dict[model] # type: ignore
else:
@@ -1275,6 +1294,9 @@
setattr(
e, "num_retries", num_retries
) ## IMPORTANT: returns the deployment's num_retries to the router
timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e)
setattr(e, "timeout", timeout)
raise e
is_coroutine = inspect.iscoroutinefunction(original_function)
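
A hedged, self-contained sketch of the pattern used here, with hypothetical names (not from this diff): the wrapper attaches num_retries and timeout to the raised exception via setattr, and upstream callers read them back with getattr defaults.

from typing import Optional, Union

import httpx  # only needed for the timeout type union, mirroring the real signature


def failing_call(timeout: Optional[Union[float, int, httpx.Timeout]], num_retries: int):
    try:
        raise TimeoutError("deployment timed out")
    except Exception as e:
        setattr(e, "timeout", timeout)          # read by the proxy for x-litellm-timeout
        setattr(e, "num_retries", num_retries)  # read by the router's retry loop
        raise


try:
    failing_call(timeout=1.0, num_retries=0)
except Exception as e:
    print(getattr(e, "timeout", None), getattr(e, "num_retries", None))  # 1.0 0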

View file

@@ -74,6 +74,12 @@ model_list:
api_base: https://exampleopenaiendpoint-production.up.railway.app/
stream_timeout: 0.001
rpm: 1000
- model_name: fake-openai-endpoint-4
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
num_retries: 50
- model_name: fake-openai-endpoint-3
litellm_params:
model: openai/my-fake-model-2
@@ -112,6 +118,12 @@ model_list:
- model_name: gpt-instruct # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model
litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct
- model_name: fake-openai-endpoint-5
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
timeout: 1
litellm_settings:
# set_verbose: True # Uncomment this if you want to see verbose logs; not recommended in production
drop_params: True
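
A hedged client-side sketch of what the new fake-openai-endpoint-5 entry exercises (same proxy/key assumptions as the earlier sketch; mock_timeout mirrors the e2e test below): with the deployment's 1-second timeout, the failed request's x-litellm-timeout header should read "1.0".

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "fake-openai-endpoint-5",
        "messages": [{"role": "user", "content": "hi"}],
        "mock_timeout": True,  # simulate a slow upstream call, as in the e2e test
        "num_retries": 0,
    },
    timeout=30,
)
print(resp.status_code, resp.headers.get("x-litellm-timeout"))  # expect a non-200 status and "1.0"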

View file

@@ -2742,3 +2742,22 @@ def test_router_prompt_management_factory():
)
print(response)
def test_router_get_model_list_from_model_alias():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
],
model_group_alias={
"my-special-fake-model-alias-name": "fake-openai-endpoint-3"
},
)
model_alias_list = router.get_model_list_from_model_alias(
model_name="gpt-3.5-turbo"
)
assert len(model_alias_list) == 0

View file

@@ -4,6 +4,7 @@ import pytest
import asyncio
import aiohttp
from large_text import text
import time
async def generate_key(
@@ -37,7 +38,14 @@ async def generate_key(
return await response.json()
async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
async def chat_completion(
session,
key: str,
model: str,
messages: list,
return_headers: bool = False,
**kwargs,
):
url = "http://0.0.0.0:4000/chat/completions"
headers = {
"Authorization": f"Bearer {key}",
@@ -53,8 +61,15 @@ async def chat_completion(session, key: str, model: str, messages: list, **kwarg
print()
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
if return_headers:
return None, response.headers
else:
raise Exception(f"Request did not return a 200 status code: {status}")
if return_headers:
return await response.json(), response.headers
else:
return await response.json()
@pytest.mark.asyncio
@@ -113,6 +128,58 @@ async def test_chat_completion_client_fallbacks(has_access):
pytest.fail("Expected this to work: {}".format(str(e)))
@pytest.mark.asyncio
async def test_chat_completion_with_retries():
"""
make chat completion call with prompt > context window. expect it to work with fallback
"""
async with aiohttp.ClientSession() as session:
model = "fake-openai-endpoint-4"
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
]
response, headers = await chat_completion(
session=session,
key="sk-1234",
model=model,
messages=messages,
mock_testing_rate_limit_error=True,
return_headers=True,
)
print(f"headers: {headers}")
assert headers["x-litellm-attempted-retries"] == "1"
assert headers["x-litellm-max-retries"] == "50"
@pytest.mark.asyncio
async def test_chat_completion_with_timeout():
"""
make chat completion call with low timeout and `mock_timeout`: true. Expect it to fail and correct timeout to be set in headers.
"""
async with aiohttp.ClientSession() as session:
model = "fake-openai-endpoint-5"
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
]
start_time = time.time()
response, headers = await chat_completion(
session=session,
key="sk-1234",
model=model,
messages=messages,
num_retries=0,
mock_timeout=True,
return_headers=True,
)
end_time = time.time()
print(f"headers: {headers}")
assert (
headers["x-litellm-timeout"] == "1.0"
) # assert model-specific timeout used
@pytest.mark.parametrize("has_access", [True, False])
@pytest.mark.asyncio
async def test_chat_completion_client_fallbacks_with_custom_message(has_access):