diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index a5fbf8c6a0..da3644a63a 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -5,11 +5,12 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
       temperature: 0.2
-  - model_name: "*"
+  - model_name: gpt-4o
     litellm_params:
-      model: "*"
-    model_info:
-      access_groups: ["default"]
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+      num_retries: 3
 
 litellm_settings:
-  success_callback: ["langsmith"]
\ No newline at end of file
+  success_callback: ["langsmith"]
+  num_retries: 0
\ No newline at end of file
diff --git a/litellm/router.py b/litellm/router.py
index 0fd39caa05..ece6a97598 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -813,7 +813,6 @@ class Router:
             kwargs["messages"] = messages
             kwargs["stream"] = stream
             kwargs["original_function"] = self._acompletion
-            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
 
             self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             request_priority = kwargs.get("priority") or self.default_priority
@@ -899,17 +898,12 @@ class Router:
             )
             self.total_calls[model_name] += 1
 
-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs, data=data
-            )
-
             _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": timeout,
                     **kwargs,
                 }
             )
@@ -1015,6 +1009,10 @@ class Router:
             }
         )
         kwargs["model_info"] = deployment.get("model_info", {})
+        kwargs["timeout"] = self._get_timeout(
+            kwargs=kwargs, data=deployment["litellm_params"]
+        )
+
         self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
 
     def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
@@ -1046,16 +1044,16 @@ class Router:
     def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
         """Helper to get timeout from kwargs or deployment params"""
         timeout = (
-            data.get(
+            kwargs.get("timeout", None)  # the params dynamically set by user
+            or kwargs.get("request_timeout", None)  # the params dynamically set by user
+            or data.get(
                 "timeout", None
             )  # timeout set on litellm_params for this deployment
             or data.get(
                 "request_timeout", None
            )  # timeout set on litellm_params for this deployment
             or self.timeout  # timeout set on router
-            or kwargs.get(
-                "timeout", None
-            )  # this uses default_litellm_params when nothing is set
+            or self.default_litellm_params.get("timeout", None)
         )
         return timeout
 
@@ -1378,7 +1376,6 @@ class Router:
             kwargs["prompt"] = prompt
             kwargs["original_function"] = self._image_generation
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
 
             response = self.function_with_fallbacks(**kwargs)
@@ -1438,7 +1435,6 @@ class Router:
             kwargs["prompt"] = prompt
             kwargs["original_function"] = self._aimage_generation
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
 
             response = await self.async_function_with_fallbacks(**kwargs)
@@ -1768,17 +1764,11 @@ class Router:
             )
             self.total_calls[model_name] += 1
 
-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs,
-                data=data,
-            )
-
             response = await litellm.arerank(
                 **{
                     **data,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": timeout,
                     **kwargs,
                 }
             )
@@ -1800,7 +1790,6 @@ class Router:
         messages = [{"role": "user", "content": "dummy-text"}]
         try:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
 
             # pick the one that is available (lowest TPM/RPM)
@@ -1826,7 +1815,7 @@ class Router:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
                 kwargs["original_function"] = self._arealtime
-                return self.function_with_retries(**kwargs)
+                return await self.async_function_with_retries(**kwargs)
             else:
                 raise e
 
@@ -1844,7 +1833,6 @@ class Router:
         kwargs["model"] = model
         kwargs["prompt"] = prompt
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
 
         # pick the one that is available (lowest TPM/RPM)
@@ -1924,7 +1912,6 @@ class Router:
                     "prompt": prompt,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -1980,7 +1967,6 @@ class Router:
             kwargs["adapter_id"] = adapter_id
             kwargs["original_function"] = self._aadapter_completion
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
 
             response = await self.async_function_with_fallbacks(**kwargs)
@@ -2024,7 +2010,6 @@ class Router:
                     "adapter_id": adapter_id,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2078,7 +2063,6 @@ class Router:
             kwargs["input"] = input
             kwargs["original_function"] = self._embedding
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = self.function_with_fallbacks(**kwargs)
             return response
@@ -2245,7 +2229,6 @@ class Router:
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_file
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
 
         response = await self.async_function_with_fallbacks(**kwargs)
@@ -2301,7 +2284,6 @@ class Router:
                     "custom_llm_provider": custom_llm_provider,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2352,7 +2334,6 @@ class Router:
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_batch
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
        kwargs.setdefault("metadata", {}).update({"model_group": model})
 
         response = await self.async_function_with_fallbacks(**kwargs)
@@ -2397,13 +2378,7 @@ class Router:
         kwargs["model_info"] = deployment.get("model_info", {})
         data = deployment["litellm_params"].copy()
         model_name = data["model"]
-        for k, v in self.default_litellm_params.items():
-            if (
-                k not in kwargs
-            ):  # prioritize model-specific params > default router params
-                kwargs[k] = v
-            elif k == metadata_variable_name:
-                kwargs[k].update(v)
+        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
 
         model_client = self._get_async_openai_model_client(
             deployment=deployment,
@@ -2420,7 +2395,6 @@ class Router:
                 "custom_llm_provider": custom_llm_provider,
                 "caching": self.cache_responses,
                 "client": model_client,
-                "timeout": self.timeout,
                 **kwargs,
             }
         )
@@ -2913,7 +2887,11 @@ class Router:
         except Exception as e:
             current_attempt = None
             original_exception = e
-
+            deployment_num_retries = getattr(e, "num_retries", None)
+            if deployment_num_retries is not None and isinstance(
+                deployment_num_retries, int
+            ):
+                num_retries = deployment_num_retries
             """
             Retry Logic
             """
@@ -3210,106 +3188,6 @@ class Router:
 
         return timeout
 
-    def function_with_retries(self, *args, **kwargs):
-        """
-        Try calling the model 3 times. Shuffle-between available deployments.
-        """
-        verbose_router_logger.debug(
-            f"Inside function with retries: args - {args}; kwargs - {kwargs}"
-        )
-        original_function = kwargs.pop("original_function")
-        num_retries = kwargs.pop("num_retries")
-        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.pop(
-            "context_window_fallbacks", self.context_window_fallbacks
-        )
-        content_policy_fallbacks = kwargs.pop(
-            "content_policy_fallbacks", self.content_policy_fallbacks
-        )
-        model_group = kwargs.get("model")
-
-        try:
-            # if the function call is successful, no exception will be raised and we'll break out of the loop
-            self._handle_mock_testing_rate_limit_error(
-                kwargs=kwargs, model_group=model_group
-            )
-            response = original_function(*args, **kwargs)
-            return response
-        except Exception as e:
-            current_attempt = None
-            original_exception = e
-            _model: Optional[str] = kwargs.get("model")  # type: ignore
-
-            if _model is None:
-                raise e  # re-raise error, if model can't be determined for loadbalancing
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
-            _healthy_deployments, _all_deployments = self._get_healthy_deployments(
-                model=_model,
-                parent_otel_span=parent_otel_span,
-            )
-
-            # raises an exception if this error should not be retries
-            self.should_retry_this_error(
-                error=e,
-                healthy_deployments=_healthy_deployments,
-                all_deployments=_all_deployments,
-                context_window_fallbacks=context_window_fallbacks,
-                regular_fallbacks=fallbacks,
-                content_policy_fallbacks=content_policy_fallbacks,
-            )
-
-            # decides how long to sleep before retry
-            _timeout = self._time_to_sleep_before_retry(
-                e=original_exception,
-                remaining_retries=num_retries,
-                num_retries=num_retries,
-                healthy_deployments=_healthy_deployments,
-                all_deployments=_all_deployments,
-            )
-
-            ## LOGGING
-            if num_retries > 0:
-                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
-
-            time.sleep(_timeout)
-            for current_attempt in range(num_retries):
-                verbose_router_logger.debug(
-                    f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
-                )
-                try:
-                    # if the function call is successful, no exception will be raised and we'll break out of the loop
-                    response = original_function(*args, **kwargs)
-                    return response
-
-                except Exception as e:
-                    ## LOGGING
-                    kwargs = self.log_retry(kwargs=kwargs, e=e)
-                    _model: Optional[str] = kwargs.get("model")  # type: ignore
-
-                    if _model is None:
-                        raise e  # re-raise error, if model can't be determined for loadbalancing
-                    parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
-                    _healthy_deployments, _ = self._get_healthy_deployments(
-                        model=_model,
-                        parent_otel_span=parent_otel_span,
-                    )
-                    remaining_retries = num_retries - current_attempt
-                    _timeout = self._time_to_sleep_before_retry(
-                        e=e,
-                        remaining_retries=remaining_retries,
-                        num_retries=num_retries,
-                        healthy_deployments=_healthy_deployments,
-                        all_deployments=_all_deployments,
-                    )
-                    time.sleep(_timeout)
-
-            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                setattr(original_exception, "max_retries", num_retries)
-                setattr(original_exception, "num_retries", current_attempt)
-
-            raise original_exception
-
     ### HELPER FUNCTIONS
 
     async def deployment_callback_on_success(
diff --git a/litellm/utils.py b/litellm/utils.py
index 6f662f6595..8957304c94 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -689,6 +689,19 @@ def client(original_function):  # noqa: PLR0915
         except Exception as e:
             raise e
 
+    def _get_num_retries(
+        kwargs: Dict[str, Any], exception: Exception
+    ) -> Tuple[Optional[int], Dict[str, Any]]:
+        num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
+        if kwargs.get("retry_policy", None):
+            num_retries = get_num_retries_from_retry_policy(
+                exception=exception,
+                retry_policy=kwargs.get("retry_policy"),
+            )
+            kwargs["retry_policy"] = reset_retry_policy()
+
+        return num_retries, kwargs
+
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
@@ -1159,20 +1172,8 @@ def client(original_function):  # noqa: PLR0915
                     raise e
 
                 call_type = original_function.__name__
+                num_retries, kwargs = _get_num_retries(kwargs=kwargs, exception=e)
                 if call_type == CallTypes.acompletion.value:
-                    num_retries = (
-                        kwargs.get("num_retries", None) or litellm.num_retries or None
-                    )
-                    if kwargs.get("retry_policy", None):
-                        num_retries = get_num_retries_from_retry_policy(
-                            exception=e,
-                            retry_policy=kwargs.get("retry_policy"),
-                        )
-                        kwargs["retry_policy"] = reset_retry_policy()
-
-                    litellm.num_retries = (
-                        None  # set retries to None to prevent infinite loops
-                    )
                     context_window_fallback_dict = kwargs.get(
                         "context_window_fallback_dict", {}
                     )
@@ -1184,6 +1185,9 @@ def client(original_function):  # noqa: PLR0915
                         num_retries and not _is_litellm_router_call
                     ):  # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
                         try:
+                            litellm.num_retries = (
+                                None  # set retries to None to prevent infinite loops
+                            )
                             kwargs["num_retries"] = num_retries
                             kwargs["original_function"] = original_function
                             if isinstance(
@@ -1205,6 +1209,10 @@ def client(original_function):  # noqa: PLR0915
                         else:
                             kwargs["model"] = context_window_fallback_dict[model]
                             return await original_function(*args, **kwargs)
+
+                setattr(
+                    e, "num_retries", num_retries
+                )  ## IMPORTANT: returns the deployment's num_retries to the router
                 raise e
 
             is_coroutine = inspect.iscoroutinefunction(original_function)
diff --git a/tests/local_testing/test_router_retries.py b/tests/local_testing/test_router_retries.py
index 24b46b6549..12bd71cfd1 100644
--- a/tests/local_testing/test_router_retries.py
+++ b/tests/local_testing/test_router_retries.py
@@ -733,3 +733,73 @@ def test_no_retry_when_no_healthy_deployments():
         )
     except Exception as e:
         print("got exception", e)
+
+
+@pytest.mark.asyncio
+async def test_router_retries_model_specific_and_global():
+    from unittest.mock import patch, MagicMock
+
+    litellm.num_retries = 0
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "num_retries": 1,
+                },
+            }
+        ]
+    )
+
+    with patch.object(
+        router, "_time_to_sleep_before_retry"
+    ) as mock_async_function_with_retries:
+        try:
+            await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except Exception as e:
+            print("got exception", e)
+
+    mock_async_function_with_retries.assert_called_once()
+
+    assert mock_async_function_with_retries.call_args.kwargs["num_retries"] == 1
+
+
+@pytest.mark.asyncio
+async def test_router_timeout_model_specific_and_global():
+    from unittest.mock import patch, MagicMock
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "anthropic-claude",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "timeout": 1,
+                },
+            }
+        ],
+        timeout=10,
+    )
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_client:
+        try:
+            await router.acompletion(
+                model="anthropic-claude",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                client=client,
+            )
+        except Exception as e:
+            print("got exception", e)
+
+    mock_client.assert_called()
+
+    assert mock_client.call_args.kwargs["timeout"] == 1
diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py
index d73343137f..1961296630 100644
--- a/tests/router_unit_tests/test_router_helper_utils.py
+++ b/tests/router_unit_tests/test_router_helper_utils.py
@@ -217,16 +217,11 @@ async def test_router_function_with_retries(model_list, sync_mode):
         "mock_response": "I'm fine, thank you!",
         "num_retries": 0,
     }
-    if sync_mode:
-        response = router.function_with_retries(
-            original_function=router._completion,
-            **data,
-        )
-    else:
-        response = await router.async_function_with_retries(
-            original_function=router._acompletion,
-            **data,
-        )
+    response = await router.async_function_with_retries(
+        original_function=router._acompletion,
+        **data,
+    )
+
     assert response.choices[0].message.content == "I'm fine, thank you!"
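Taken together, the router.py and utils.py changes give num_retries and timeout a single precedence order: per-request kwargs, then the deployment's litellm_params, then router-level settings, then default_litellm_params / the global litellm.num_retries. A minimal sketch of that behavior, mirroring the two new tests; it assumes only litellm's public Router API and is illustrative, not part of the diff:

    # Precedence sketch: deployment-level values beat router/global defaults,
    # and a per-request value beats both.
    import litellm
    from litellm import Router

    litellm.num_retries = 0  # global default: lowest precedence for retries

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "num_retries": 1,  # deployment-level: overrides the global value
                    "timeout": 1,  # deployment-level: overrides the router-level timeout
                },
            }
        ],
        timeout=10,  # router-level fallback, used only when nothing narrower is set
    )

    # Inside an async context (needs an event loop and real credentials):
    #     response = await router.acompletion(
    #         model="gpt-3.5-turbo",
    #         messages=[{"role": "user", "content": "hi"}],
    #         timeout=2,  # per-request value: highest precedence in _get_timeout
    #     )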