Add attempted-retries and timeout values to response headers + more testing (#7926)

* feat(router.py): add retry headers to response

makes it easy to add testing to ensure model-specific retries are respected

* fix(add_retry_headers.py): clarify attempted retries vs. max retries

* test(test_fallbacks.py): add test for checking if max retries set for model is respected

* test(test_fallbacks.py): assert values for attempted retries and max retries are as expected

* fix(utils.py): return timeout in litellm proxy response headers

* test(test_fallbacks.py): add test to assert model specific timeout used on timeout error

* test: add bad model with timeout to proxy

* fix: fix linting error

* fix(router.py): fix get model list from model alias

* test: loosen test restriction - account for other events on proxy
Krish Dholakia 2025-01-22 22:19:44 -08:00 committed by GitHub
parent bc546d82a1
commit 513b1904ab
9 changed files with 245 additions and 31 deletions
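
The quickest way to see the change is from the client side. A minimal, hypothetical check against a locally running proxy (assumes http://0.0.0.0:4000 with master key sk-1234 and the fake-openai-endpoint-4 deployment added to the CI config below; which headers appear depends on the code path, as the diffs show):

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "fake-openai-endpoint-4",
        "messages": [{"role": "user", "content": "hi"}],
    },
    timeout=30,
)

# Header values are strings; "x-litellm-max-retries" only appears when retries were
# attempted and the deployment defines num_retries.
print(resp.headers.get("x-litellm-attempted-retries"))  # e.g. "0" on a first-try success
print(resp.headers.get("x-litellm-max-retries"))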


@@ -1,10 +1,11 @@
 model_list:
-  - model_name: openai-gpt-4o
+  - model_name: gpt-3.5-turbo-end-user-test
     litellm_params:
-      model: openai/my-fake-openai-endpoint
-      api_key: sk-1234
-      api_base: https://exampleopenaiendpoint-production.up.railway.app
-  - model_name: openai-o1
+      model: gpt-3.5-turbo
+      region_name: "eu"
+    model_info:
+      id: "1"
+  - model_name: gpt-3.5-turbo-end-user-test
     litellm_params:
       model: openai/random_sleep
       api_base: http://0.0.0.0:8090


@@ -787,9 +787,10 @@ def get_custom_headers(
     hidden_params: Optional[dict] = None,
     fastest_response_batch_completion: Optional[bool] = None,
     request_data: Optional[dict] = {},
+    timeout: Optional[Union[float, int, httpx.Timeout]] = None,
     **kwargs,
 ) -> dict:
-    exclude_values = {"", None}
+    exclude_values = {"", None, "None"}
     hidden_params = hidden_params or {}
     headers = {
         "x-litellm-call-id": call_id,
@@ -812,6 +813,7 @@ def get_custom_headers(
             if fastest_response_batch_completion is not None
             else None
         ),
+        "x-litellm-timeout": str(timeout) if timeout is not None else None,
         **{k: str(v) for k, v in kwargs.items()},
     }
     if request_data:
@@ -3638,14 +3640,28 @@ async def chat_completion( # noqa: PLR0915
             litellm_debug_info,
         )

+        timeout = getattr(
+            e, "timeout", None
+        )  # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
+        custom_headers = get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
+            version=version,
+            response_cost=0,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            request_data=data,
+            timeout=timeout,
+        )
+        headers = getattr(e, "headers", {}) or {}
+        headers.update(custom_headers)
         if isinstance(e, HTTPException):
+            # print("e.headers={}".format(e.headers))
             raise ProxyException(
                 message=getattr(e, "detail", str(e)),
                 type=getattr(e, "type", "None"),
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
-                headers=getattr(e, "headers", {}),
+                headers=headers,
             )

         error_msg = f"{str(e)}"
         raise ProxyException(
@@ -3653,7 +3669,7 @@ async def chat_completion( # noqa: PLR0915
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
-            headers=getattr(e, "headers", {}),
+            headers=headers,
         )
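
The practical effect of this hunk is that error responses from /chat/completions now carry the LiteLLM debugging headers too, including the new x-litellm-timeout. A hypothetical client-side check, mirroring the new E2E test further down (assumes the fake-openai-endpoint-5 deployment with timeout: 1 from the CI config and the mock_timeout flag that test relies on):

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "fake-openai-endpoint-5",
        "messages": [{"role": "user", "content": "hi"}],
        "mock_timeout": True,  # force a timeout error, as in test_chat_completion_with_timeout
        "num_retries": 0,
    },
    timeout=30,
)

print(resp.status_code)                       # non-200, since the call timed out
print(resp.headers.get("x-litellm-timeout"))  # expected "1.0" -> the model-specific timeout was used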


@@ -57,6 +57,7 @@ from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
 from litellm.router_strategy.simple_shuffle import simple_shuffle
 from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
+from litellm.router_utils.add_retry_headers import add_retry_headers_to_response
 from litellm.router_utils.batch_utils import (
     _get_router_metadata_variable_name,
     replace_model_in_jsonl,
@@ -3090,12 +3091,15 @@ class Router:
             )
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await self.make_call(original_function, *args, **kwargs)
+            response = add_retry_headers_to_response(
+                response=response, attempted_retries=0, max_retries=None
+            )

             return response
         except Exception as e:
             current_attempt = None
             original_exception = e
             deployment_num_retries = getattr(e, "num_retries", None)
             if deployment_num_retries is not None and isinstance(
                 deployment_num_retries, int
             ):
@@ -3156,6 +3160,12 @@
                         response
                     ):  # async errors are often returned as coroutines
                         response = await response

+                    response = add_retry_headers_to_response(
+                        response=response,
+                        attempted_retries=current_attempt + 1,
+                        max_retries=num_retries,
+                    )
+
                     return response
                 except Exception as e:
@@ -3214,6 +3224,15 @@
         mock_testing_rate_limit_error: Optional[bool] = kwargs.pop(
             "mock_testing_rate_limit_error", None
         )
+
+        available_models = self.get_model_list(model_name=model_group)
+        num_retries: Optional[int] = None
+        if available_models is not None and len(available_models) == 1:
+            num_retries = cast(
+                Optional[int], available_models[0]["litellm_params"].get("num_retries")
+            )
+
         if (
             mock_testing_rate_limit_error is not None
             and mock_testing_rate_limit_error is True
@@ -3225,6 +3244,7 @@
                 model=model_group,
                 llm_provider="",
                 message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
+                num_retries=num_retries,
             )

     def should_retry_this_error(
@@ -4776,6 +4796,37 @@
             model_names.append(m["model_name"])
         return model_names

+    def get_model_list_from_model_alias(
+        self, model_name: Optional[str] = None
+    ) -> List[DeploymentTypedDict]:
+        """
+        Helper function to get model list from model alias.
+
+        Used by `.get_model_list` to get model list from model alias.
+        """
+        returned_models: List[DeploymentTypedDict] = []
+        for model_alias, model_value in self.model_group_alias.items():
+            if model_name is not None and model_alias != model_name:
+                continue
+
+            if isinstance(model_value, str):
+                _router_model_name: str = model_value
+            elif isinstance(model_value, dict):
+                _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
+                if _model_value["hidden"] is True:
+                    continue
+                else:
+                    _router_model_name = _model_value["model"]
+            else:
+                continue
+
+            returned_models.extend(
+                self._get_all_deployments(
+                    model_name=_router_model_name, model_alias=model_alias
+                )
+            )
+        return returned_models
+
     def get_model_list(
         self, model_name: Optional[str] = None
     ) -> Optional[List[DeploymentTypedDict]]:
@@ -4789,24 +4840,9 @@
             returned_models.extend(self._get_all_deployments(model_name=model_name))

             if hasattr(self, "model_group_alias"):
-                for model_alias, model_value in self.model_group_alias.items():
-
-                    if isinstance(model_value, str):
-                        _router_model_name: str = model_value
-                    elif isinstance(model_value, dict):
-                        _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
-                        if _model_value["hidden"] is True:
-                            continue
-                        else:
-                            _router_model_name = _model_value["model"]
-                    else:
-                        continue
-
-                    returned_models.extend(
-                        self._get_all_deployments(
-                            model_name=_router_model_name, model_alias=model_alias
-                        )
-                    )
+                returned_models.extend(
+                    self.get_model_list_from_model_alias(model_name=model_name)
+                )

             if len(returned_models) == 0:  # check if wildcard route
                 potential_wildcard_models = self.pattern_router.route(model_name)
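
At the SDK level the retry metadata lands in the response's hidden params rather than HTTP headers; the proxy forwards it from there. A minimal sketch (model name and values are illustrative; mock_response avoids a real provider call):

import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-test-model",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "mock_response": "hi",  # no real provider call is made
                "num_retries": 3,       # per-deployment retry budget, now also surfaced to callers
            },
        }
    ]
)


async def main():
    resp = await router.acompletion(
        model="my-test-model",
        messages=[{"role": "user", "content": "hello"}],
    )
    # On a first-try success the router stamps attempted_retries=0 and leaves max_retries unset:
    print(resp._hidden_params.get("additional_headers"))
    # expected -> {"x-litellm-attempted-retries": 0}


asyncio.run(main())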


@@ -0,0 +1,40 @@
from typing import Any, Optional, Union

from pydantic import BaseModel

from litellm.types.utils import HiddenParams


def add_retry_headers_to_response(
    response: Any,
    attempted_retries: int,
    max_retries: Optional[int] = None,
) -> Any:
    """
    Add retry headers to the response.
    """

    if response is None or not isinstance(response, BaseModel):
        return response

    retry_headers = {
        "x-litellm-attempted-retries": attempted_retries,
    }
    if max_retries is not None:
        retry_headers["x-litellm-max-retries"] = max_retries

    hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
        response, "_hidden_params", {}
    )

    if hidden_params is None:
        hidden_params = {}
    elif isinstance(hidden_params, HiddenParams):
        hidden_params = hidden_params.model_dump()

    hidden_params.setdefault("additional_headers", {})
    hidden_params["additional_headers"].update(retry_headers)

    setattr(response, "_hidden_params", hidden_params)
    return response
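
A brief usage sketch of the helper; the expected output follows directly from the code above:

from litellm import ModelResponse
from litellm.router_utils.add_retry_headers import add_retry_headers_to_response

response = ModelResponse()
response = add_retry_headers_to_response(
    response=response, attempted_retries=2, max_retries=5
)
print(response._hidden_params["additional_headers"])
# expected -> {"x-litellm-attempted-retries": 2, "x-litellm-max-retries": 5}

# Anything that is not a pydantic BaseModel (including None) is passed through untouched:
assert add_retry_headers_to_response(response=None, attempted_retries=1) is None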


@@ -352,6 +352,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     output_cost_per_token: Optional[float]
     input_cost_per_second: Optional[float]
     output_cost_per_second: Optional[float]
+    num_retries: Optional[int]
     ## MOCK RESPONSES ##
     mock_response: Optional[Union[str, ModelResponse, Exception]]


@@ -669,6 +669,7 @@ def _get_wrapper_num_retries(
     Get the number of retries from the kwargs and the retry policy.
     Used for the wrapper functions.
     """
+
     num_retries = kwargs.get("num_retries", None)
     if num_retries is None:
         num_retries = litellm.num_retries
@@ -684,6 +685,21 @@ def _get_wrapper_num_retries(
     return num_retries, kwargs


+def _get_wrapper_timeout(
+    kwargs: Dict[str, Any], exception: Exception
+) -> Optional[Union[float, int, httpx.Timeout]]:
+    """
+    Get the timeout from the kwargs
+    Used for the wrapper functions.
+    """
+
+    timeout = cast(
+        Optional[Union[float, int, httpx.Timeout]], kwargs.get("timeout", None)
+    )
+
+    return timeout
+
+
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()
@@ -1243,9 +1259,11 @@ def client(original_function): # noqa: PLR0915
                     _is_litellm_router_call = "model_group" in kwargs.get(
                         "metadata", {}
                     )  # check if call from litellm.router/proxy
+
                     if (
                         num_retries and not _is_litellm_router_call
                     ):  # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
                         try:
                             litellm.num_retries = (
                                 None  # set retries to None to prevent infinite loops
@@ -1266,6 +1284,7 @@ def client(original_function): # noqa: PLR0915
                         and context_window_fallback_dict
                         and model in context_window_fallback_dict
                     ):
+
                         if len(args) > 0:
                             args[0] = context_window_fallback_dict[model]  # type: ignore
                         else:
@@ -1275,6 +1294,9 @@ def client(original_function): # noqa: PLR0915
                     setattr(
                         e, "num_retries", num_retries
                     )  ## IMPORTANT: returns the deployment's num_retries to the router
+
+                    timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e)
+                    setattr(e, "timeout", timeout)
                     raise e

             is_coroutine = inspect.iscoroutinefunction(original_function)
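
A small sketch of the new helper in isolation (assumes it lives in litellm/utils.py next to _get_wrapper_num_retries, per the commit message; the exception argument is currently unused by the helper):

import httpx

from litellm.utils import _get_wrapper_timeout

err = ValueError("placeholder")
print(_get_wrapper_timeout(kwargs={"timeout": httpx.Timeout(1.0)}, exception=err))  # the timeout passed in kwargs
print(_get_wrapper_timeout(kwargs={}, exception=err))                               # None when no timeout was passed

The wrapper then attaches this value to the exception via setattr(e, "timeout", timeout), which is what the proxy's error handler reads to emit x-litellm-timeout.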


@@ -74,6 +74,12 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
       stream_timeout: 0.001
       rpm: 1000
+  - model_name: fake-openai-endpoint-4
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      num_retries: 50
   - model_name: fake-openai-endpoint-3
     litellm_params:
       model: openai/my-fake-model-2
@@ -112,6 +118,12 @@ model_list:
   - model_name: gpt-instruct # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model
     litellm_params:
       model: text-completion-openai/gpt-3.5-turbo-instruct
+  - model_name: fake-openai-endpoint-5
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      timeout: 1
 litellm_settings:
   # set_verbose: True # Uncomment this if you want to see verbose logs; not recommended in production
   drop_params: True


@@ -2742,3 +2742,22 @@ def test_router_prompt_management_factory():
     )

     print(response)
+
+
+def test_router_get_model_list_from_model_alias():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            }
+        ],
+        model_group_alias={
+            "my-special-fake-model-alias-name": "fake-openai-endpoint-3"
+        },
+    )
+
+    model_alias_list = router.get_model_list_from_model_alias(
+        model_name="gpt-3.5-turbo"
+    )
+    assert len(model_alias_list) == 0


@@ -4,6 +4,7 @@ import pytest
 import asyncio
 import aiohttp
 from large_text import text
+import time


 async def generate_key(
@@ -37,7 +38,14 @@ async def generate_key(
         return await response.json()


-async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
+async def chat_completion(
+    session,
+    key: str,
+    model: str,
+    messages: list,
+    return_headers: bool = False,
+    **kwargs,
+):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
         "Authorization": f"Bearer {key}",
@@ -53,8 +61,15 @@ async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
         print()

         if status != 200:
-            raise Exception(f"Request did not return a 200 status code: {status}")
-        return await response.json()
+            if return_headers:
+                return None, response.headers
+            else:
+                raise Exception(f"Request did not return a 200 status code: {status}")
+
+        if return_headers:
+            return await response.json(), response.headers
+        else:
+            return await response.json()


 @pytest.mark.asyncio
@@ -113,6 +128,58 @@ async def test_chat_completion_client_fallbacks(has_access):
             pytest.fail("Expected this to work: {}".format(str(e)))


+@pytest.mark.asyncio
+async def test_chat_completion_with_retries():
+    """
+    Make a chat completion call that triggers a mocked rate limit error.
+    Expect the retry headers to reflect the model-specific num_retries.
+    """
+    async with aiohttp.ClientSession() as session:
+        model = "fake-openai-endpoint-4"
+        messages = [
+            {"role": "system", "content": text},
+            {"role": "user", "content": "Who was Alexander?"},
+        ]
+        response, headers = await chat_completion(
+            session=session,
+            key="sk-1234",
+            model=model,
+            messages=messages,
+            mock_testing_rate_limit_error=True,
+            return_headers=True,
+        )
+        print(f"headers: {headers}")
+        assert headers["x-litellm-attempted-retries"] == "1"
+        assert headers["x-litellm-max-retries"] == "50"
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_timeout():
+    """
+    Make a chat completion call with a low timeout and `mock_timeout`: true.
+    Expect it to fail and the model-specific timeout to be set in the response headers.
+    """
+    async with aiohttp.ClientSession() as session:
+        model = "fake-openai-endpoint-5"
+        messages = [
+            {"role": "system", "content": text},
+            {"role": "user", "content": "Who was Alexander?"},
+        ]
+        start_time = time.time()
+        response, headers = await chat_completion(
+            session=session,
+            key="sk-1234",
+            model=model,
+            messages=messages,
+            num_retries=0,
+            mock_timeout=True,
+            return_headers=True,
+        )
+        end_time = time.time()
+        print(f"headers: {headers}")
+        assert (
+            headers["x-litellm-timeout"] == "1.0"
+        )  # assert model-specific timeout used
+
+
 @pytest.mark.parametrize("has_access", [True, False])
 @pytest.mark.asyncio
 async def test_chat_completion_client_fallbacks_with_custom_message(has_access):