diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 9933ff64b7..15acf2a778 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,10 +1,11 @@
 model_list:
-  - model_name: openai-gpt-4o
+  - model_name: gpt-3.5-turbo-end-user-test
     litellm_params:
-      model: openai/my-fake-openai-endpoint
-      api_key: sk-1234
-      api_base: https://exampleopenaiendpoint-production.up.railway.app
-  - model_name: openai-o1
+      model: gpt-3.5-turbo
+      region_name: "eu"
+    model_info:
+      id: "1"
+  - model_name: gpt-3.5-turbo-end-user-test
     litellm_params:
       model: openai/random_sleep
       api_base: http://0.0.0.0:8090
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 2c126e54c0..4fb589fe37 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -787,9 +787,10 @@ def get_custom_headers(
     hidden_params: Optional[dict] = None,
     fastest_response_batch_completion: Optional[bool] = None,
     request_data: Optional[dict] = {},
+    timeout: Optional[Union[float, int, httpx.Timeout]] = None,
     **kwargs,
 ) -> dict:
-    exclude_values = {"", None}
+    exclude_values = {"", None, "None"}
     hidden_params = hidden_params or {}
     headers = {
         "x-litellm-call-id": call_id,
@@ -812,6 +813,7 @@ def get_custom_headers(
             if fastest_response_batch_completion is not None
             else None
         ),
+        "x-litellm-timeout": str(timeout) if timeout is not None else None,
         **{k: str(v) for k, v in kwargs.items()},
     }
     if request_data:
@@ -3638,14 +3640,28 @@ async def chat_completion(  # noqa: PLR0915
             litellm_debug_info,
         )
 
+        timeout = getattr(
+            e, "timeout", None
+        )  # returns the timeout set by the wrapper. Used for testing if model-specific timeouts are set correctly
+
+        custom_headers = get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
+            version=version,
+            response_cost=0,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            request_data=data,
+            timeout=timeout,
+        )
+        headers = getattr(e, "headers", {}) or {}
+        headers.update(custom_headers)
+
         if isinstance(e, HTTPException):
-            # print("e.headers={}".format(e.headers))
             raise ProxyException(
                 message=getattr(e, "detail", str(e)),
                 type=getattr(e, "type", "None"),
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
-                headers=getattr(e, "headers", {}),
+                headers=headers,
             )
         error_msg = f"{str(e)}"
         raise ProxyException(
@@ -3653,7 +3669,7 @@ async def chat_completion(  # noqa: PLR0915
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
-            headers=getattr(e, "headers", {}),
+            headers=headers,
         )
 
 
diff --git a/litellm/router.py b/litellm/router.py
index 794a2a5404..58809197ee 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -57,6 +57,7 @@ from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
 from litellm.router_strategy.simple_shuffle import simple_shuffle
 from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
+from litellm.router_utils.add_retry_headers import add_retry_headers_to_response
 from litellm.router_utils.batch_utils import (
     _get_router_metadata_variable_name,
     replace_model_in_jsonl,
@@ -3090,12 +3091,15 @@ class Router:
 
             )  # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await self.make_call(original_function, *args, **kwargs)
-
+            response = add_retry_headers_to_response(
+                response=response, attempted_retries=0, max_retries=None
+            )
             return response
         except Exception as e:
             current_attempt = None
             original_exception = e
             deployment_num_retries = getattr(e, "num_retries", None)
+
             if deployment_num_retries is not None and isinstance(
                 deployment_num_retries, int
             ):
@@ -3156,6 +3160,12 @@ class Router:
                         response
                     ):  # async errors are often returned as coroutines
                         response = await response
+
+                    response = add_retry_headers_to_response(
+                        response=response,
+                        attempted_retries=current_attempt + 1,
+                        max_retries=num_retries,
+                    )
                     return response
 
                 except Exception as e:
@@ -3214,6 +3224,15 @@ class Router:
         mock_testing_rate_limit_error: Optional[bool] = kwargs.pop(
             "mock_testing_rate_limit_error", None
         )
+
+        available_models = self.get_model_list(model_name=model_group)
+        num_retries: Optional[int] = None
+
+        if available_models is not None and len(available_models) == 1:
+            num_retries = cast(
+                Optional[int], available_models[0]["litellm_params"].get("num_retries")
+            )
+
         if (
             mock_testing_rate_limit_error is not None
             and mock_testing_rate_limit_error is True
@@ -3225,6 +3244,7 @@ class Router:
                 model=model_group,
                 llm_provider="",
                 message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
+                num_retries=num_retries,
             )
 
     def should_retry_this_error(
@@ -4776,6 +4796,37 @@ class Router:
             model_names.append(m["model_name"])
         return model_names
 
+    def get_model_list_from_model_alias(
+        self, model_name: Optional[str] = None
+    ) -> List[DeploymentTypedDict]:
+        """
+        Helper function to get model list from model alias.
+
+        Used by `.get_model_list` to get model list from model alias.
+        """
+        returned_models: List[DeploymentTypedDict] = []
+        for model_alias, model_value in self.model_group_alias.items():
+            if model_name is not None and model_alias != model_name:
+                continue
+            if isinstance(model_value, str):
+                _router_model_name: str = model_value
+            elif isinstance(model_value, dict):
+                _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
+                if _model_value["hidden"] is True:
+                    continue
+                else:
+                    _router_model_name = _model_value["model"]
+            else:
+                continue
+
+            returned_models.extend(
+                self._get_all_deployments(
+                    model_name=_router_model_name, model_alias=model_alias
+                )
+            )
+
+        return returned_models
+
     def get_model_list(
         self, model_name: Optional[str] = None
     ) -> Optional[List[DeploymentTypedDict]]:
@@ -4789,24 +4840,9 @@ class Router:
             returned_models.extend(self._get_all_deployments(model_name=model_name))
 
         if hasattr(self, "model_group_alias"):
-            for model_alias, model_value in self.model_group_alias.items():
-
-                if isinstance(model_value, str):
-                    _router_model_name: str = model_value
-                elif isinstance(model_value, dict):
-                    _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
-                    if _model_value["hidden"] is True:
-                        continue
-                    else:
-                        _router_model_name = _model_value["model"]
-                else:
-                    continue
-
-                returned_models.extend(
-                    self._get_all_deployments(
-                        model_name=_router_model_name, model_alias=model_alias
-                    )
-                )
+            returned_models.extend(
+                self.get_model_list_from_model_alias(model_name=model_name)
+            )
 
         if len(returned_models) == 0:  # check if wildcard route
             potential_wildcard_models = self.pattern_router.route(model_name)
diff --git a/litellm/router_utils/add_retry_headers.py b/litellm/router_utils/add_retry_headers.py
new file mode 100644
index 0000000000..c2771e4939
--- /dev/null
+++ b/litellm/router_utils/add_retry_headers.py
@@ -0,0 +1,40 @@
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel
+
+from litellm.types.utils import HiddenParams
+
+
+def add_retry_headers_to_response(
+    response: Any,
+    attempted_retries: int,
+    max_retries: Optional[int] = None,
+) -> Any:
+    """
+    Add retry headers to the response.
+    """
+
+    if response is None or not isinstance(response, BaseModel):
+        return response
+
+    retry_headers = {
+        "x-litellm-attempted-retries": attempted_retries,
+    }
+    if max_retries is not None:
+        retry_headers["x-litellm-max-retries"] = max_retries
+
+    hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
+        response, "_hidden_params", {}
+    )
+
+    if hidden_params is None:
+        hidden_params = {}
+    elif isinstance(hidden_params, HiddenParams):
+        hidden_params = hidden_params.model_dump()
+
+    hidden_params.setdefault("additional_headers", {})
+    hidden_params["additional_headers"].update(retry_headers)
+
+    setattr(response, "_hidden_params", hidden_params)
+
+    return response
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 1575db27e0..9393bb2213 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -352,6 +352,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     output_cost_per_token: Optional[float]
     input_cost_per_second: Optional[float]
     output_cost_per_second: Optional[float]
+    num_retries: Optional[int]
 
     ## MOCK RESPONSES ##
     mock_response: Optional[Union[str, ModelResponse, Exception]]
diff --git a/litellm/utils.py b/litellm/utils.py
index f3fec6c2a0..a0f452c6b4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -669,6 +669,7 @@ def _get_wrapper_num_retries(
     Get the number of retries from the kwargs and the retry policy.
     Used for the wrapper functions.
     """
+
     num_retries = kwargs.get("num_retries", None)
     if num_retries is None:
         num_retries = litellm.num_retries
@@ -684,6 +685,21 @@ def _get_wrapper_num_retries(
     return num_retries, kwargs
 
 
+def _get_wrapper_timeout(
+    kwargs: Dict[str, Any], exception: Exception
+) -> Optional[Union[float, int, httpx.Timeout]]:
+    """
+    Get the timeout from the kwargs.
+    Used for the wrapper functions.
+    """
+
+    timeout = cast(
+        Optional[Union[float, int, httpx.Timeout]], kwargs.get("timeout", None)
+    )
+
+    return timeout
+
+
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()
 
@@ -1243,9 +1259,11 @@ def client(original_function):  # noqa: PLR0915
             _is_litellm_router_call = "model_group" in kwargs.get(
                 "metadata", {}
             )  # check if call from litellm.router/proxy
+
             if (
                 num_retries and not _is_litellm_router_call
             ):  # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
+
                 try:
                     litellm.num_retries = (
                         None  # set retries to None to prevent infinite loops
@@ -1266,6 +1284,7 @@ def client(original_function):  # noqa: PLR0915
                     and context_window_fallback_dict
                     and model in context_window_fallback_dict
                 ):
+
                     if len(args) > 0:
                         args[0] = context_window_fallback_dict[model]  # type: ignore
                     else:
@@ -1275,6 +1294,9 @@ def client(original_function):  # noqa: PLR0915
             setattr(
                 e, "num_retries", num_retries
             )  ## IMPORTANT: returns the deployment's num_retries to the router
+
+            timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e)
+            setattr(e, "timeout", timeout)
             raise e
 
     is_coroutine = inspect.iscoroutinefunction(original_function)
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index 2dfc7cd3ff..04a5625086 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -74,6 +74,12 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
       stream_timeout: 0.001
       rpm: 1000
+  - model_name: fake-openai-endpoint-4
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      num_retries: 50
   - model_name: fake-openai-endpoint-3
     litellm_params:
       model: openai/my-fake-model-2
@@ -112,6 +118,12 @@ model_list:
   - model_name: gpt-instruct  # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model
     litellm_params:
       model: text-completion-openai/gpt-3.5-turbo-instruct
+  - model_name: fake-openai-endpoint-5
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      timeout: 1
 litellm_settings:
   # set_verbose: True  # Uncomment this if you want to see verbose logs; not recommended in production
   drop_params: True
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index ec72d00ba4..1ef7607c26 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -2742,3 +2742,22 @@ def test_router_prompt_management_factory():
     )
 
     print(response)
+
+
+def test_router_get_model_list_from_model_alias():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            }
+        ],
+        model_group_alias={
+            "my-special-fake-model-alias-name": "fake-openai-endpoint-3"
+        },
+    )
+
+    model_alias_list = router.get_model_list_from_model_alias(
+        model_name="gpt-3.5-turbo"
+    )
+    assert len(model_alias_list) == 0
diff --git a/tests/test_fallbacks.py b/tests/test_fallbacks.py
index b2e27689ae..2f39d5e985 100644
--- a/tests/test_fallbacks.py
+++ b/tests/test_fallbacks.py
@@ -4,6 +4,7 @@ import pytest
 import asyncio
 import aiohttp
 from large_text import text
+import time
 
 
 async def generate_key(
@@ -37,7 +38,14 @@ async def generate_key(
     return await response.json()
 
 
-async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
+async def chat_completion(
+    session,
+    key: str,
+    model: str,
+    messages: list,
+    return_headers: bool = False,
+    **kwargs,
+):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
         "Authorization": f"Bearer {key}",
@@ -53,8 +61,15 @@ async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
         print()
 
         if status != 200:
-            raise Exception(f"Request did not return a 200 status code: {status}")
-        return await response.json()
+            if return_headers:
+                return None, response.headers
+            else:
+                raise Exception(f"Request did not return a 200 status code: {status}")
+
+        if return_headers:
+            return await response.json(), response.headers
+        else:
+            return await response.json()
 
 
 @pytest.mark.asyncio
@@ -113,6 +128,58 @@ async def test_chat_completion_client_fallbacks(has_access):
         pytest.fail("Expected this to work: {}".format(str(e)))
 
 
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_retries():
+    """
+    make chat completion call that triggers a mock rate limit error. expect the attempted retry count and the deployment's max retries to be returned in the response headers
+    """
+    async with aiohttp.ClientSession() as session:
+        model = "fake-openai-endpoint-4"
+        messages = [
+            {"role": "system", "content": text},
+            {"role": "user", "content": "Who was Alexander?"},
+        ]
+        response, headers = await chat_completion(
+            session=session,
+            key="sk-1234",
+            model=model,
+            messages=messages,
+            mock_testing_rate_limit_error=True,
+            return_headers=True,
+        )
+        print(f"headers: {headers}")
+        assert headers["x-litellm-attempted-retries"] == "1"
+        assert headers["x-litellm-max-retries"] == "50"
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_timeout():
+    """
+    make chat completion call with low timeout and `mock_timeout`: true. Expect it to fail and correct timeout to be set in headers.
+    """
+    async with aiohttp.ClientSession() as session:
+        model = "fake-openai-endpoint-5"
+        messages = [
+            {"role": "system", "content": text},
+            {"role": "user", "content": "Who was Alexander?"},
+        ]
+        start_time = time.time()
+        response, headers = await chat_completion(
+            session=session,
+            key="sk-1234",
+            model=model,
+            messages=messages,
+            num_retries=0,
+            mock_timeout=True,
+            return_headers=True,
+        )
+        end_time = time.time()
+        print(f"headers: {headers}")
+        assert (
+            headers["x-litellm-timeout"] == "1.0"
+        )  # assert model-specific timeout used
+
+
 @pytest.mark.parametrize("has_access", [True, False])
 @pytest.mark.asyncio
 async def test_chat_completion_client_fallbacks_with_custom_message(has_access):