Add attempted-retries and timeout values to response headers + more testing (#7926)

All checks were successful: Read Version from pyproject.toml / read-version (push), successful in 14s
* feat(router.py): add retry headers to response; makes it easy to add testing to ensure model-specific retries are respected
* fix(add_retry_headers.py): clarify attempted retries vs. max retries
* test(test_fallbacks.py): add test for checking if max retries set for model is respected
* test(test_fallbacks.py): assert values for attempted retries and max retries are as expected
* fix(utils.py): return timeout in litellm proxy response headers
* test(test_fallbacks.py): add test to assert model-specific timeout used on timeout error
* test: add bad model with timeout to proxy
* fix: fix linting error
* fix(router.py): fix get model list from model alias
* test: loosen test restriction - account for other events on proxy
This commit is contained in: parent bc546d82a1, commit 513b1904ab
9 changed files with 245 additions and 31 deletions
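
Not part of the diff: a minimal client-side sketch of what this change exposes, assuming a locally running proxy at http://0.0.0.0:4000 with master key sk-1234 and the fake-openai-endpoint-4 deployment (num_retries: 50) from the test config further down. The retry/timeout metadata shows up as HTTP response headers.

import requests

# hedged sketch: inspect the retry/timeout headers the proxy now returns
resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "fake-openai-endpoint-4",      # deployment with num_retries: 50 (see config below)
        "messages": [{"role": "user", "content": "hi"}],
        "mock_testing_rate_limit_error": True,  # force a retry, as in the new tests
    },
)

print(resp.headers.get("x-litellm-attempted-retries"))  # e.g. "1"
print(resp.headers.get("x-litellm-max-retries"))        # e.g. "50"
print(resp.headers.get("x-litellm-timeout"))            # set when a timeout value applies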
@@ -1,10 +1,11 @@
 model_list:
-  - model_name: openai-gpt-4o
+  - model_name: gpt-3.5-turbo-end-user-test
     litellm_params:
-      model: openai/my-fake-openai-endpoint
-      api_key: sk-1234
-      api_base: https://exampleopenaiendpoint-production.up.railway.app
-  - model_name: openai-o1
+      model: gpt-3.5-turbo
+      region_name: "eu"
+    model_info:
+      id: "1"
+  - model_name: gpt-3.5-turbo-end-user-test
     litellm_params:
       model: openai/random_sleep
       api_base: http://0.0.0.0:8090
@@ -787,9 +787,10 @@ def get_custom_headers(
     hidden_params: Optional[dict] = None,
     fastest_response_batch_completion: Optional[bool] = None,
     request_data: Optional[dict] = {},
+    timeout: Optional[Union[float, int, httpx.Timeout]] = None,
     **kwargs,
 ) -> dict:
-    exclude_values = {"", None}
+    exclude_values = {"", None, "None"}
     hidden_params = hidden_params or {}
     headers = {
         "x-litellm-call-id": call_id,
@@ -812,6 +813,7 @@ def get_custom_headers(
             if fastest_response_batch_completion is not None
             else None
         ),
+        "x-litellm-timeout": str(timeout) if timeout is not None else None,
         **{k: str(v) for k, v in kwargs.items()},
     }
     if request_data:
@@ -3638,14 +3640,28 @@ async def chat_completion(  # noqa: PLR0915
             litellm_debug_info,
         )

+        timeout = getattr(
+            e, "timeout", None
+        )  # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
+
+        custom_headers = get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
+            version=version,
+            response_cost=0,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            request_data=data,
+            timeout=timeout,
+        )
+        headers = getattr(e, "headers", {}) or {}
+        headers.update(custom_headers)
+
         if isinstance(e, HTTPException):
-            # print("e.headers={}".format(e.headers))
             raise ProxyException(
                 message=getattr(e, "detail", str(e)),
                 type=getattr(e, "type", "None"),
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
-                headers=getattr(e, "headers", {}),
+                headers=headers,
             )
         error_msg = f"{str(e)}"
         raise ProxyException(
@@ -3653,7 +3669,7 @@ async def chat_completion(  # noqa: PLR0915
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
-            headers=getattr(e, "headers", {}),
+            headers=headers,
         )

@@ -57,6 +57,7 @@ from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
 from litellm.router_strategy.simple_shuffle import simple_shuffle
 from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
+from litellm.router_utils.add_retry_headers import add_retry_headers_to_response
 from litellm.router_utils.batch_utils import (
     _get_router_metadata_variable_name,
     replace_model_in_jsonl,
@@ -3090,12 +3091,15 @@ class Router:
            )
            # if the function call is successful, no exception will be raised and we'll break out of the loop
            response = await self.make_call(original_function, *args, **kwargs)
+           response = add_retry_headers_to_response(
+               response=response, attempted_retries=0, max_retries=None
+           )
            return response
        except Exception as e:
            current_attempt = None
            original_exception = e
            deployment_num_retries = getattr(e, "num_retries", None)

            if deployment_num_retries is not None and isinstance(
                deployment_num_retries, int
            ):
@@ -3156,6 +3160,12 @@ class Router:
                        response
                    ):  # async errors are often returned as coroutines
                        response = await response
+
+                    response = add_retry_headers_to_response(
+                        response=response,
+                        attempted_retries=current_attempt + 1,
+                        max_retries=num_retries,
+                    )
                    return response

                except Exception as e:
@@ -3214,6 +3224,15 @@ class Router:
        mock_testing_rate_limit_error: Optional[bool] = kwargs.pop(
            "mock_testing_rate_limit_error", None
        )
+
+       available_models = self.get_model_list(model_name=model_group)
+       num_retries: Optional[int] = None
+
+       if available_models is not None and len(available_models) == 1:
+           num_retries = cast(
+               Optional[int], available_models[0]["litellm_params"].get("num_retries")
+           )
+
        if (
            mock_testing_rate_limit_error is not None
            and mock_testing_rate_limit_error is True
@@ -3225,6 +3244,7 @@ class Router:
                model=model_group,
                llm_provider="",
                message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
+               num_retries=num_retries,
            )

    def should_retry_this_error(
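
A minimal sketch of how the router change above surfaces retry metadata on a successful async call. Assumptions: litellm's mock_response parameter (so no real provider is called) and that acompletion goes through the async retry wrapper shown above; the model and key names are made up.

import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "fake-model",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "mock_response": "hello",  # avoid a real provider call
                "num_retries": 3,          # per-deployment retries (see LiteLLMParamsTypedDict below)
            },
        }
    ]
)

async def main():
    response = await router.acompletion(
        model="fake-model", messages=[{"role": "user", "content": "hi"}]
    )
    # on the success path add_retry_headers_to_response stamps attempted_retries=0,
    # so additional_headers should include {"x-litellm-attempted-retries": 0}
    print(response._hidden_params.get("additional_headers"))

asyncio.run(main())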
@@ -4776,6 +4796,37 @@ class Router:
             model_names.append(m["model_name"])
         return model_names

+    def get_model_list_from_model_alias(
+        self, model_name: Optional[str] = None
+    ) -> List[DeploymentTypedDict]:
+        """
+        Helper function to get model list from model alias.
+
+        Used by `.get_model_list` to get model list from model alias.
+        """
+        returned_models: List[DeploymentTypedDict] = []
+        for model_alias, model_value in self.model_group_alias.items():
+            if model_name is not None and model_alias != model_name:
+                continue
+            if isinstance(model_value, str):
+                _router_model_name: str = model_value
+            elif isinstance(model_value, dict):
+                _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
+                if _model_value["hidden"] is True:
+                    continue
+                else:
+                    _router_model_name = _model_value["model"]
+            else:
+                continue
+
+            returned_models.extend(
+                self._get_all_deployments(
+                    model_name=_router_model_name, model_alias=model_alias
+                )
+            )
+
+        return returned_models
+
     def get_model_list(
         self, model_name: Optional[str] = None
     ) -> Optional[List[DeploymentTypedDict]]:
@@ -4789,24 +4840,9 @@ class Router:
         returned_models.extend(self._get_all_deployments(model_name=model_name))

         if hasattr(self, "model_group_alias"):
-            for model_alias, model_value in self.model_group_alias.items():
-                if isinstance(model_value, str):
-                    _router_model_name: str = model_value
-                elif isinstance(model_value, dict):
-                    _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
-                    if _model_value["hidden"] is True:
-                        continue
-                    else:
-                        _router_model_name = _model_value["model"]
-                else:
-                    continue
-
-                returned_models.extend(
-                    self._get_all_deployments(
-                        model_name=_router_model_name, model_alias=model_alias
-                    )
-                )
+            returned_models.extend(
+                self.get_model_list_from_model_alias(model_name=model_name)
+            )

         if len(returned_models) == 0:  # check if wildcard route
             potential_wildcard_models = self.pattern_router.route(model_name)
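
A hedged usage sketch of the new helper's alias handling: string aliases resolve to their underlying deployments, while dict aliases marked hidden are skipped. The alias names below are made up, and it assumes Router accepts dict-valued entries in model_group_alias as the code above implies.

from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
    ],
    model_group_alias={
        "prod-alias": "gpt-3.5-turbo",                                 # plain string alias
        "internal-alias": {"model": "gpt-3.5-turbo", "hidden": True},  # filtered out by the helper
    },
)

aliases = router.get_model_list_from_model_alias()
print(len(aliases))  # expect 1: only the deployment resolved via "prod-alias"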
litellm/router_utils/add_retry_headers.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel
+
+from litellm.types.utils import HiddenParams
+
+
+def add_retry_headers_to_response(
+    response: Any,
+    attempted_retries: int,
+    max_retries: Optional[int] = None,
+) -> Any:
+    """
+    Add retry headers to the request
+    """
+
+    if response is None or not isinstance(response, BaseModel):
+        return response
+
+    retry_headers = {
+        "x-litellm-attempted-retries": attempted_retries,
+    }
+    if max_retries is not None:
+        retry_headers["x-litellm-max-retries"] = max_retries
+
+    hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
+        response, "_hidden_params", {}
+    )
+
+    if hidden_params is None:
+        hidden_params = {}
+    elif isinstance(hidden_params, HiddenParams):
+        hidden_params = hidden_params.model_dump()
+
+    hidden_params.setdefault("additional_headers", {})
+    hidden_params["additional_headers"].update(retry_headers)
+
+    setattr(response, "_hidden_params", hidden_params)
+
+    return response
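
A small usage sketch of the helper above, not taken from the diff: ModelResponse stands in for any pydantic response object; the headers land in the hidden params, which the proxy then copies into HTTP response headers.

from litellm import ModelResponse
from litellm.router_utils.add_retry_headers import add_retry_headers_to_response

response = ModelResponse()  # stand-in for a completed LLM response
response = add_retry_headers_to_response(
    response=response, attempted_retries=1, max_retries=50
)
print(response._hidden_params["additional_headers"])
# e.g. {'x-litellm-attempted-retries': 1, 'x-litellm-max-retries': 50}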
@@ -352,6 +352,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     output_cost_per_token: Optional[float]
     input_cost_per_second: Optional[float]
     output_cost_per_second: Optional[float]
+    num_retries: Optional[int]
     ## MOCK RESPONSES ##
     mock_response: Optional[Union[str, ModelResponse, Exception]]

@@ -669,6 +669,7 @@ def _get_wrapper_num_retries(
     Get the number of retries from the kwargs and the retry policy.
     Used for the wrapper functions.
     """
+
     num_retries = kwargs.get("num_retries", None)
     if num_retries is None:
         num_retries = litellm.num_retries
@@ -684,6 +685,21 @@ def _get_wrapper_num_retries(
     return num_retries, kwargs


+def _get_wrapper_timeout(
+    kwargs: Dict[str, Any], exception: Exception
+) -> Optional[Union[float, int, httpx.Timeout]]:
+    """
+    Get the timeout from the kwargs
+    Used for the wrapper functions.
+    """
+
+    timeout = cast(
+        Optional[Union[float, int, httpx.Timeout]], kwargs.get("timeout", None)
+    )
+
+    return timeout
+
+
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()

@@ -1243,9 +1259,11 @@ def client(original_function):  # noqa: PLR0915
                 _is_litellm_router_call = "model_group" in kwargs.get(
                     "metadata", {}
                 )  # check if call from litellm.router/proxy
+
                 if (
                     num_retries and not _is_litellm_router_call
                 ):  # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
+
                     try:
                         litellm.num_retries = (
                             None  # set retries to None to prevent infinite loops
@@ -1266,6 +1284,7 @@ def client(original_function):  # noqa: PLR0915
                     and context_window_fallback_dict
                     and model in context_window_fallback_dict
                 ):
+
                     if len(args) > 0:
                         args[0] = context_window_fallback_dict[model]  # type: ignore
                     else:
@@ -1275,6 +1294,9 @@ def client(original_function):  # noqa: PLR0915
                 setattr(
                     e, "num_retries", num_retries
                 )  ## IMPORTANT: returns the deployment's num_retries to the router
+
+                timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e)
+                setattr(e, "timeout", timeout)
                 raise e

             is_coroutine = inspect.iscoroutinefunction(original_function)
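
The pattern in this hunk, illustrated standalone: the wrapper stamps the deployment's num_retries and timeout onto the raised exception so the router and proxy can read them back with getattr. The exception class and values below are illustrative only.

class FakeProviderError(Exception):
    pass

e = FakeProviderError("request timed out")
setattr(e, "num_retries", 2)  # read back by the router's retry loop
setattr(e, "timeout", 1.0)    # surfaced by the proxy as x-litellm-timeout

print(getattr(e, "num_retries", None), getattr(e, "timeout", None))  # 2 1.0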
@@ -74,6 +74,12 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
       stream_timeout: 0.001
       rpm: 1000
+  - model_name: fake-openai-endpoint-4
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      num_retries: 50
   - model_name: fake-openai-endpoint-3
     litellm_params:
       model: openai/my-fake-model-2
@@ -112,6 +118,12 @@ model_list:
   - model_name: gpt-instruct # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model
     litellm_params:
       model: text-completion-openai/gpt-3.5-turbo-instruct
+  - model_name: fake-openai-endpoint-5
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      timeout: 1
 litellm_settings:
   # set_verbose: True # Uncomment this if you want to see verbose logs; not recommended in production
   drop_params: True
@@ -2742,3 +2742,22 @@ def test_router_prompt_management_factory():
     )

     print(response)
+
+
+def test_router_get_model_list_from_model_alias():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            }
+        ],
+        model_group_alias={
+            "my-special-fake-model-alias-name": "fake-openai-endpoint-3"
+        },
+    )
+
+    model_alias_list = router.get_model_list_from_model_alias(
+        model_name="gpt-3.5-turbo"
+    )
+    assert len(model_alias_list) == 0
@@ -4,6 +4,7 @@ import pytest
 import asyncio
 import aiohttp
 from large_text import text
+import time


 async def generate_key(
@@ -37,7 +38,14 @@ async def generate_key(
     return await response.json()


-async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
+async def chat_completion(
+    session,
+    key: str,
+    model: str,
+    messages: list,
+    return_headers: bool = False,
+    **kwargs,
+):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
         "Authorization": f"Bearer {key}",
@@ -53,8 +61,15 @@ async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
         print()

         if status != 200:
-            raise Exception(f"Request did not return a 200 status code: {status}")
+            if return_headers:
+                return None, response.headers
+            else:
+                raise Exception(f"Request did not return a 200 status code: {status}")

-        return await response.json()
+        if return_headers:
+            return await response.json(), response.headers
+        else:
+            return await response.json()


 @pytest.mark.asyncio
@@ -113,6 +128,58 @@ async def test_chat_completion_client_fallbacks(has_access):
         pytest.fail("Expected this to work: {}".format(str(e)))


+@pytest.mark.asyncio
+async def test_chat_completion_with_retries():
+    """
+    make chat completion call with prompt > context window. expect it to work with fallback
+    """
+    async with aiohttp.ClientSession() as session:
+        model = "fake-openai-endpoint-4"
+        messages = [
+            {"role": "system", "content": text},
+            {"role": "user", "content": "Who was Alexander?"},
+        ]
+        response, headers = await chat_completion(
+            session=session,
+            key="sk-1234",
+            model=model,
+            messages=messages,
+            mock_testing_rate_limit_error=True,
+            return_headers=True,
+        )
+        print(f"headers: {headers}")
+        assert headers["x-litellm-attempted-retries"] == "1"
+        assert headers["x-litellm-max-retries"] == "50"
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_timeout():
+    """
+    make chat completion call with low timeout and `mock_timeout`: true. Expect it to fail and correct timeout to be set in headers.
+    """
+    async with aiohttp.ClientSession() as session:
+        model = "fake-openai-endpoint-5"
+        messages = [
+            {"role": "system", "content": text},
+            {"role": "user", "content": "Who was Alexander?"},
+        ]
+        start_time = time.time()
+        response, headers = await chat_completion(
+            session=session,
+            key="sk-1234",
+            model=model,
+            messages=messages,
+            num_retries=0,
+            mock_timeout=True,
+            return_headers=True,
+        )
+        end_time = time.time()
+        print(f"headers: {headers}")
+        assert (
+            headers["x-litellm-timeout"] == "1.0"
+        )  # assert model-specific timeout used
+
+
 @pytest.mark.parametrize("has_access", [True, False])
 @pytest.mark.asyncio
 async def test_chat_completion_client_fallbacks_with_custom_message(has_access):