Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Add attempted-retries and timeout values to response headers + more testing (#7926)

* feat(router.py): add retry headers to response; makes it easy to add testing to ensure model-specific retries are respected
* fix(add_retry_headers.py): clarify attempted retries vs. max retries
* test(test_fallbacks.py): add test for checking if max retries set for model is respected
* test(test_fallbacks.py): assert values for attempted retries and max retries are as expected
* fix(utils.py): return timeout in litellm proxy response headers
* test(test_fallbacks.py): add test to assert model-specific timeout used on timeout error
* test: add bad model with timeout to proxy
* fix: fix linting error
* fix(router.py): fix get model list from model alias
* test: loosen test restriction - account for other events on proxy
This commit is contained in:
parent a1a81fc0f3
commit b286bab075

9 changed files with 245 additions and 31 deletions
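For a sense of what this change looks like from the caller's side, here is a minimal client sketch (not part of the diff) that reads the new response headers from a running LiteLLM proxy. The proxy URL, API key, and model name are assumptions taken from the test config further below.

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/chat/completions",  # assumed local proxy address (from the test config)
    headers={"Authorization": "Bearer sk-1234"},  # assumed test key
    json={
        "model": "fake-openai-endpoint-4",
        "messages": [{"role": "user", "content": "Who was Alexander?"}],
    },
    timeout=60,
)

# Headers introduced/propagated by this commit (values are stringified by the proxy)
print(resp.headers.get("x-litellm-attempted-retries"))
print(resp.headers.get("x-litellm-max-retries"))
print(resp.headers.get("x-litellm-timeout"))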
@@ -1,10 +1,11 @@
model_list:
  - model_name: openai-gpt-4o
  - model_name: gpt-3.5-turbo-end-user-test
    litellm_params:
      model: openai/my-fake-openai-endpoint
      api_key: sk-1234
      api_base: https://exampleopenaiendpoint-production.up.railway.app
  - model_name: openai-o1
      model: gpt-3.5-turbo
      region_name: "eu"
    model_info:
      id: "1"
  - model_name: gpt-3.5-turbo-end-user-test
    litellm_params:
      model: openai/random_sleep
      api_base: http://0.0.0.0:8090
@@ -787,9 +787,10 @@ def get_custom_headers(
    hidden_params: Optional[dict] = None,
    fastest_response_batch_completion: Optional[bool] = None,
    request_data: Optional[dict] = {},
    timeout: Optional[Union[float, int, httpx.Timeout]] = None,
    **kwargs,
) -> dict:
    exclude_values = {"", None}
    exclude_values = {"", None, "None"}
    hidden_params = hidden_params or {}
    headers = {
        "x-litellm-call-id": call_id,

@@ -812,6 +813,7 @@ def get_custom_headers(
            if fastest_response_batch_completion is not None
            else None
        ),
        "x-litellm-timeout": str(timeout) if timeout is not None else None,
        **{k: str(v) for k, v in kwargs.items()},
    }
    if request_data:

@@ -3638,14 +3640,28 @@ async def chat_completion(  # noqa: PLR0915
            litellm_debug_info,
        )

        timeout = getattr(
            e, "timeout", None
        )  # returns the timeout set by the wrapper. Used for testing if model-specific timeouts are set correctly

        custom_headers = get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            version=version,
            response_cost=0,
            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            request_data=data,
            timeout=timeout,
        )
        headers = getattr(e, "headers", {}) or {}
        headers.update(custom_headers)

        if isinstance(e, HTTPException):
            # print("e.headers={}".format(e.headers))
            raise ProxyException(
                message=getattr(e, "detail", str(e)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
                headers=getattr(e, "headers", {}),
                headers=headers,
            )
        error_msg = f"{str(e)}"
        raise ProxyException(

@@ -3653,7 +3669,7 @@ async def chat_completion(  # noqa: PLR0915
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
            headers=getattr(e, "headers", {}),
            headers=headers,
        )
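The `exclude_values` change above means the literal string "None" is now filtered out of proxy response headers alongside empty strings and None. A simplified, standalone sketch of that filtering behavior (not the proxy's actual function):

from typing import Optional, Union

import httpx


def build_headers(
    timeout: Optional[Union[float, int, httpx.Timeout]] = None, **kwargs
) -> dict:
    exclude_values = {"", None, "None"}
    headers = {
        "x-litellm-timeout": str(timeout) if timeout is not None else None,
        **{k: str(v) for k, v in kwargs.items()},
    }
    # drop empty / None / "None" values so they never reach the wire
    return {k: v for k, v in headers.items() if v not in exclude_values}


print(build_headers(timeout=1.0, x_litellm_attempted_retries=1))
# {'x-litellm-timeout': '1.0', 'x_litellm_attempted_retries': '1'}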
@@ -57,6 +57,7 @@ from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_strategy.simple_shuffle import simple_shuffle
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
from litellm.router_utils.add_retry_headers import add_retry_headers_to_response
from litellm.router_utils.batch_utils import (
    _get_router_metadata_variable_name,
    replace_model_in_jsonl,

@@ -3090,12 +3091,15 @@ class Router:
            )
            # if the function call is successful, no exception will be raised and we'll break out of the loop
            response = await self.make_call(original_function, *args, **kwargs)

            response = add_retry_headers_to_response(
                response=response, attempted_retries=0, max_retries=None
            )
            return response
        except Exception as e:
            current_attempt = None
            original_exception = e
            deployment_num_retries = getattr(e, "num_retries", None)

            if deployment_num_retries is not None and isinstance(
                deployment_num_retries, int
            ):

@@ -3156,6 +3160,12 @@ class Router:
                        response
                    ):  # async errors are often returned as coroutines
                        response = await response

                    response = add_retry_headers_to_response(
                        response=response,
                        attempted_retries=current_attempt + 1,
                        max_retries=num_retries,
                    )
                    return response

                except Exception as e:

@@ -3214,6 +3224,15 @@ class Router:
        mock_testing_rate_limit_error: Optional[bool] = kwargs.pop(
            "mock_testing_rate_limit_error", None
        )

        available_models = self.get_model_list(model_name=model_group)
        num_retries: Optional[int] = None

        if available_models is not None and len(available_models) == 1:
            num_retries = cast(
                Optional[int], available_models[0]["litellm_params"].get("num_retries")
            )

        if (
            mock_testing_rate_limit_error is not None
            and mock_testing_rate_limit_error is True

@@ -3225,6 +3244,7 @@ class Router:
                model=model_group,
                llm_provider="",
                message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
                num_retries=num_retries,
            )

    def should_retry_this_error(

@@ -4776,6 +4796,37 @@ class Router:
            model_names.append(m["model_name"])
        return model_names

    def get_model_list_from_model_alias(
        self, model_name: Optional[str] = None
    ) -> List[DeploymentTypedDict]:
        """
        Helper function to get model list from model alias.

        Used by `.get_model_list` to get model list from model alias.
        """
        returned_models: List[DeploymentTypedDict] = []
        for model_alias, model_value in self.model_group_alias.items():
            if model_name is not None and model_alias != model_name:
                continue
            if isinstance(model_value, str):
                _router_model_name: str = model_value
            elif isinstance(model_value, dict):
                _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
                if _model_value["hidden"] is True:
                    continue
                else:
                    _router_model_name = _model_value["model"]
            else:
                continue

            returned_models.extend(
                self._get_all_deployments(
                    model_name=_router_model_name, model_alias=model_alias
                )
            )

        return returned_models

    def get_model_list(
        self, model_name: Optional[str] = None
    ) -> Optional[List[DeploymentTypedDict]]:

@@ -4789,24 +4840,9 @@ class Router:
            returned_models.extend(self._get_all_deployments(model_name=model_name))

        if hasattr(self, "model_group_alias"):
            for model_alias, model_value in self.model_group_alias.items():

                if isinstance(model_value, str):
                    _router_model_name: str = model_value
                elif isinstance(model_value, dict):
                    _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
                    if _model_value["hidden"] is True:
                        continue
                    else:
                        _router_model_name = _model_value["model"]
                else:
                    continue

                returned_models.extend(
                    self._get_all_deployments(
                        model_name=_router_model_name, model_alias=model_alias
                    )
                )
        returned_models.extend(
            self.get_model_list_from_model_alias(model_name=model_name)
        )

        if len(returned_models) == 0:  # check if wildcard route
            potential_wildcard_models = self.pattern_router.route(model_name)
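The mock rate-limit path above reads `num_retries` from the deployment's `litellm_params` when a model group resolves to exactly one deployment. A small sketch of that configuration; the model name, api_base, and key are placeholders mirroring the test config, and the Router constructor usage is an assumption:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint-4",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                "num_retries": 50,  # per-deployment retries, surfaced as x-litellm-max-retries
            },
        }
    ]
)

deployments = router.get_model_list(model_name="fake-openai-endpoint-4")
print(deployments[0]["litellm_params"].get("num_retries"))  # 50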
litellm/router_utils/add_retry_headers.py (new file, 40 lines)

@@ -0,0 +1,40 @@
from typing import Any, Optional, Union

from pydantic import BaseModel

from litellm.types.utils import HiddenParams


def add_retry_headers_to_response(
    response: Any,
    attempted_retries: int,
    max_retries: Optional[int] = None,
) -> Any:
    """
    Add retry headers to the response
    """

    if response is None or not isinstance(response, BaseModel):
        return response

    retry_headers = {
        "x-litellm-attempted-retries": attempted_retries,
    }
    if max_retries is not None:
        retry_headers["x-litellm-max-retries"] = max_retries

    hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
        response, "_hidden_params", {}
    )

    if hidden_params is None:
        hidden_params = {}
    elif isinstance(hidden_params, HiddenParams):
        hidden_params = hidden_params.model_dump()

    hidden_params.setdefault("additional_headers", {})
    hidden_params["additional_headers"].update(retry_headers)

    setattr(response, "_hidden_params", hidden_params)

    return response
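A short usage sketch for the new helper (illustrative, not part of the commit): attach retry metadata to a response object and read it back from `_hidden_params`, which the proxy later surfaces as `x-litellm-*` response headers.

from litellm.router_utils.add_retry_headers import add_retry_headers_to_response
from litellm.types.utils import ModelResponse

response = ModelResponse()  # any pydantic response object works; non-BaseModel inputs pass through unchanged
response = add_retry_headers_to_response(
    response=response, attempted_retries=1, max_retries=50
)
print(response._hidden_params["additional_headers"])
# {'x-litellm-attempted-retries': 1, 'x-litellm-max-retries': 50}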
@@ -352,6 +352,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
    output_cost_per_token: Optional[float]
    input_cost_per_second: Optional[float]
    output_cost_per_second: Optional[float]
    num_retries: Optional[int]
    ## MOCK RESPONSES ##
    mock_response: Optional[Union[str, ModelResponse, Exception]]
@@ -669,6 +669,7 @@ def _get_wrapper_num_retries(
    Get the number of retries from the kwargs and the retry policy.
    Used for the wrapper functions.
    """

    num_retries = kwargs.get("num_retries", None)
    if num_retries is None:
        num_retries = litellm.num_retries

@@ -684,6 +685,21 @@ def _get_wrapper_num_retries(
    return num_retries, kwargs


def _get_wrapper_timeout(
    kwargs: Dict[str, Any], exception: Exception
) -> Optional[Union[float, int, httpx.Timeout]]:
    """
    Get the timeout from the kwargs.
    Used for the wrapper functions.
    """

    timeout = cast(
        Optional[Union[float, int, httpx.Timeout]], kwargs.get("timeout", None)
    )

    return timeout


def client(original_function):  # noqa: PLR0915
    rules_obj = Rules()

@@ -1243,9 +1259,11 @@ def client(original_function):  # noqa: PLR0915
                _is_litellm_router_call = "model_group" in kwargs.get(
                    "metadata", {}
                )  # check if call from litellm.router/proxy

                if (
                    num_retries and not _is_litellm_router_call
                ):  # only enter this if call is not from litellm router/proxy. router has its own logic for retrying

                    try:
                        litellm.num_retries = (
                            None  # set retries to None to prevent infinite loops

@@ -1266,6 +1284,7 @@ def client(original_function):  # noqa: PLR0915
                    and context_window_fallback_dict
                    and model in context_window_fallback_dict
                ):

                    if len(args) > 0:
                        args[0] = context_window_fallback_dict[model]  # type: ignore
                    else:

@@ -1275,6 +1294,9 @@ def client(original_function):  # noqa: PLR0915
                setattr(
                    e, "num_retries", num_retries
                )  ## IMPORTANT: returns the deployment's num_retries to the router

                timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e)
                setattr(e, "timeout", timeout)
                raise e

            is_coroutine = inspect.iscoroutinefunction(original_function)
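The wrapper changes above pass the deployment's retry and timeout settings back to the router/proxy by attaching them to the raised exception. A standalone sketch of that mechanism (not litellm's code):

def call_with_metadata(num_retries: int, timeout: float) -> None:
    try:
        raise TimeoutError("request timed out")  # stand-in for a provider error
    except Exception as e:
        setattr(e, "num_retries", num_retries)  # read back by the router's retry loop
        setattr(e, "timeout", timeout)  # read back by get_custom_headers() for x-litellm-timeout
        raise e


try:
    call_with_metadata(num_retries=0, timeout=1.0)
except Exception as e:
    print(getattr(e, "num_retries", None), getattr(e, "timeout", None))  # 0 1.0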
@@ -74,6 +74,12 @@ model_list:
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
      stream_timeout: 0.001
      rpm: 1000
  - model_name: fake-openai-endpoint-4
    litellm_params:
      model: openai/my-fake-model
      api_key: my-fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
      num_retries: 50
  - model_name: fake-openai-endpoint-3
    litellm_params:
      model: openai/my-fake-model-2

@@ -112,6 +118,12 @@ model_list:
  - model_name: gpt-instruct  # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model
    litellm_params:
      model: text-completion-openai/gpt-3.5-turbo-instruct
  - model_name: fake-openai-endpoint-5
    litellm_params:
      model: openai/my-fake-model
      api_key: my-fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
      timeout: 1
litellm_settings:
  # set_verbose: True  # Uncomment this if you want to see verbose logs; not recommended in production
  drop_params: True
@@ -2742,3 +2742,22 @@ def test_router_prompt_management_factory():
    )

    print(response)


def test_router_get_model_list_from_model_alias():
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        model_group_alias={
            "my-special-fake-model-alias-name": "fake-openai-endpoint-3"
        },
    )

    model_alias_list = router.get_model_list_from_model_alias(
        model_name="gpt-3.5-turbo"
    )
    assert len(model_alias_list) == 0
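Related to the alias handling this test exercises: a string alias maps straight to a model group, while the dict form can mark an alias as hidden so `get_model_list_from_model_alias` skips it. A hedged sketch; the alias names are illustrative and the exact constructor behavior is assumed:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint-3",
            "litellm_params": {"model": "openai/my-fake-model-2", "api_key": "my-fake-key"},
        }
    ],
    model_group_alias={
        "public-alias": "fake-openai-endpoint-3",
        "internal-alias": {"model": "fake-openai-endpoint-3", "hidden": True},
    },
)

print(len(router.get_model_list_from_model_alias(model_name="public-alias")))    # 1
print(len(router.get_model_list_from_model_alias(model_name="internal-alias")))  # 0 (hidden alias is skipped)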
@@ -4,6 +4,7 @@ import pytest
import asyncio
import aiohttp
from large_text import text
import time


async def generate_key(

@@ -37,7 +38,14 @@ async def generate_key(
    return await response.json()


async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
async def chat_completion(
    session,
    key: str,
    model: str,
    messages: list,
    return_headers: bool = False,
    **kwargs,
):
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",

@@ -53,8 +61,15 @@ async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()
            if return_headers:
                return None, response.headers
            else:
                raise Exception(f"Request did not return a 200 status code: {status}")

        if return_headers:
            return await response.json(), response.headers
        else:
            return await response.json()


@pytest.mark.asyncio

@@ -113,6 +128,58 @@ async def test_chat_completion_client_fallbacks(has_access):
        pytest.fail("Expected this to work: {}".format(str(e)))


@pytest.mark.asyncio
async def test_chat_completion_with_retries():
    """
    make chat completion call against a model with `num_retries` set and a mocked rate limit error. Expect attempted/max retries to be returned in the response headers.
    """
    async with aiohttp.ClientSession() as session:
        model = "fake-openai-endpoint-4"
        messages = [
            {"role": "system", "content": text},
            {"role": "user", "content": "Who was Alexander?"},
        ]
        response, headers = await chat_completion(
            session=session,
            key="sk-1234",
            model=model,
            messages=messages,
            mock_testing_rate_limit_error=True,
            return_headers=True,
        )
        print(f"headers: {headers}")
        assert headers["x-litellm-attempted-retries"] == "1"
        assert headers["x-litellm-max-retries"] == "50"


@pytest.mark.asyncio
async def test_chat_completion_with_timeout():
    """
    make chat completion call with low timeout and `mock_timeout`: true. Expect it to fail and correct timeout to be set in headers.
    """
    async with aiohttp.ClientSession() as session:
        model = "fake-openai-endpoint-5"
        messages = [
            {"role": "system", "content": text},
            {"role": "user", "content": "Who was Alexander?"},
        ]
        start_time = time.time()
        response, headers = await chat_completion(
            session=session,
            key="sk-1234",
            model=model,
            messages=messages,
            num_retries=0,
            mock_timeout=True,
            return_headers=True,
        )
        end_time = time.time()
        print(f"headers: {headers}")
        assert (
            headers["x-litellm-timeout"] == "1.0"
        )  # assert model-specific timeout used


@pytest.mark.parametrize("has_access", [True, False])
@pytest.mark.asyncio
async def test_chat_completion_client_fallbacks_with_custom_message(has_access):