[Fix] Router cooldown logic - use % thresholds instead of allowed fails to cool down deployments (#5698)

* move cooldown logic to its own helper

* add new track deployment metrics folder

* increment successes and fails for deployment in current minute

* fix cooldown logic

* fix test_aaarouter_dynamic_cooldown_message_retry_time

* fix test_single_deployment_no_cooldowns_test_prod_mock_completion_calls

* clean up get from deployment test

* fix _async_get_healthy_deployments

* add mock InternalServerError

* test deployment failing 25% of requests

* add test_high_traffic_cooldowns_one_bad_deployment

* fix vertex load test

* add test for rate limit error models in cooldown

* change default cooldown time

* fix cooldown message time

* fix cooldown on 429 error

* fix doc string for _should_cooldown_deployment

* fix sync cooldown logic in router
Ishaan Jaff 2024-09-14 18:01:19 -07:00 committed by GitHub
parent fc0dd3e3c2
commit 8f155327f6
11 changed files with 836 additions and 175 deletions
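
The core behavioral change: instead of cooling a deployment down after a fixed number of failures per minute (allowed_fails), the router now looks at what percentage of the deployment's current-minute traffic failed, so high-traffic deployments are not pulled from rotation by a handful of absolute failures. A minimal sketch of that decision, assuming the new tracking module exposes per-minute getters mirroring the increment_* helpers added in this diff; the getter names and the 50% threshold are assumptions, not confirmed by this commit:

```python
# Sketch of the percentage-threshold cooldown decision. Only the increment_*
# counterparts and DEFAULT_COOLDOWN_TIME_SECONDS appear in this diff; the
# getters and threshold below are assumed.
from litellm.router_utils.router_callbacks.track_deployment_metrics import (
    get_deployment_failures_for_current_minute,   # assumed getter
    get_deployment_successes_for_current_minute,  # assumed getter
)

FAILURE_THRESHOLD_PERCENT = 0.5  # assumed default: cool down above 50% failures


def _should_cooldown_deployment(litellm_router_instance, deployment_id: str) -> bool:
    """Cool down a deployment only if it failed a meaningful share of the
    requests it served this minute, rather than a fixed absolute count."""
    fails = get_deployment_failures_for_current_minute(
        litellm_router_instance=litellm_router_instance,
        deployment_id=deployment_id,
    )
    successes = get_deployment_successes_for_current_minute(
        litellm_router_instance=litellm_router_instance,
        deployment_id=deployment_id,
    )
    total = fails + successes
    if total == 0:
        return False  # no traffic this minute, nothing to judge
    return (fails / total) > FAILURE_THRESHOLD_PERCENT
```

This is why the commit starts counting successes at all: a deployment failing 25% of heavy traffic (per the test above) stays in rotation, while one failing most of its requests is cooled down quickly.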


@@ -54,6 +54,13 @@ from litellm.router_utils.client_initalization_utils import (
)
from litellm.router_utils.cooldown_cache import CooldownCache
from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
from litellm.router_utils.cooldown_handlers import (
DEFAULT_COOLDOWN_TIME_SECONDS,
_async_get_cooldown_deployments,
_async_get_cooldown_deployments_with_debug_info,
_get_cooldown_deployments,
_set_cooldown_deployments,
)
from litellm.router_utils.fallback_event_handlers import (
log_failure_fallback_event,
log_success_fallback_event,
@@ -61,6 +68,10 @@ from litellm.router_utils.fallback_event_handlers import (
run_sync_fallback,
)
from litellm.router_utils.handle_error import send_llm_exception_alert
from litellm.router_utils.router_callbacks.track_deployment_metrics import (
increment_deployment_failures_for_current_minute,
increment_deployment_successes_for_current_minute,
)
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
Assistant,
@@ -346,7 +357,7 @@ class Router:
self.allowed_fails = allowed_fails
else:
self.allowed_fails = litellm.allowed_fails
self.cooldown_time = cooldown_time or 60
self.cooldown_time = cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS
self.cooldown_cache = CooldownCache(
cache=self.cache, default_cooldown_time=self.cooldown_time
)
@@ -444,6 +455,10 @@
litellm._async_success_callback.append(self.deployment_callback_on_success)
else:
litellm._async_success_callback.append(self.deployment_callback_on_success)
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.sync_deployment_callback_on_success)
else:
litellm.success_callback = [self.sync_deployment_callback_on_success]
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
@@ -3001,7 +3016,9 @@ class Router:
"litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
str(new_exception),
traceback.format_exc(),
await self._async_get_cooldown_deployments_with_debug_info(),
await _async_get_cooldown_deployments_with_debug_info(
litellm_router_instance=self
),
)
)
fallback_failure_exception_str = str(new_exception)
@@ -3536,6 +3553,11 @@
key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
)
increment_deployment_successes_for_current_minute(
litellm_router_instance=self,
deployment_id=id,
)
except Exception as e:
verbose_router_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
@@ -3544,6 +3566,31 @@
)
pass
def sync_deployment_callback_on_success(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
id = None
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get("model_group", None)
model_info = kwargs["litellm_params"].get("model_info", {}) or {}
id = model_info.get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
if id is not None:
increment_deployment_successes_for_current_minute(
litellm_router_instance=self,
deployment_id=id,
)
def deployment_callback_on_failure(
self,
kwargs, # kwargs to completion
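
The increment_deployment_successes_for_current_minute / increment_deployment_failures_for_current_minute calls above come from the new litellm/router_utils/router_callbacks/track_deployment_metrics.py module. A plausible minimal implementation, assuming the counters live in the router's existing in-memory cache and are keyed by deployment id plus the current UTC minute; the key format and TTL are illustrative:

```python
from datetime import datetime, timezone


def _current_minute() -> str:
    # Counters are scoped to the current minute, matching the helper names.
    return datetime.now(timezone.utc).strftime("%H-%M")


def increment_deployment_successes_for_current_minute(
    litellm_router_instance, deployment_id: str
) -> None:
    # Reuses the router's cache (get_cache/set_cache signatures as used
    # elsewhere in router.py); key format and TTL are assumptions.
    key = f"{deployment_id}:successes:{_current_minute()}"
    cache = litellm_router_instance.cache
    current = cache.get_cache(key=key) or 0
    cache.set_cache(key=key, value=current + 1, ttl=120)


def increment_deployment_failures_for_current_minute(
    litellm_router_instance, deployment_id: str
) -> None:
    key = f"{deployment_id}:fails:{_current_minute()}"
    cache = litellm_router_instance.cache
    current = cache.get_cache(key=key) or 0
    cache.set_cache(key=key, value=current + 1, ttl=120)
```

Registering sync_deployment_callback_on_success alongside the async success callback ensures synchronous completion() calls feed the same counters, so the failure percentage reflects all traffic.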
@@ -3595,7 +3642,12 @@
if isinstance(_model_info, dict):
deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments(
increment_deployment_failures_for_current_minute(
litellm_router_instance=self,
deployment_id=deployment_id,
)
_set_cooldown_deployments(
litellm_router_instance=self,
exception_status=exception_status,
original_exception=exception,
deployment=deployment_id,
@@ -3753,155 +3805,6 @@
)
return False
def _set_cooldown_deployments(
self,
original_exception: Any,
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
):
"""
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
or
the exception is not one that should be immediately retried (e.g. 401)
"""
if self.disable_cooldowns is True:
return
if deployment is None:
return
if (
self._is_cooldown_required(
model_id=deployment,
exception_status=exception_status,
exception_str=str(original_exception),
)
is False
):
return
if deployment in self.provider_default_deployment_ids:
return
_allowed_fails = self.get_allowed_fails_from_policy(
exception=original_exception,
)
allowed_fails = (
_allowed_fails if _allowed_fails is not None else self.allowed_fails
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get current fails for deployment
# update the number of failed calls
# if it's > allowed fails
# cooldown deployment
current_fails = self.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1
verbose_router_logger.debug(
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
)
cooldown_time = self.cooldown_time or 1
if time_to_cooldown is not None:
cooldown_time = time_to_cooldown
if isinstance(exception_status, str):
try:
exception_status = int(exception_status)
except Exception as e:
verbose_router_logger.debug(
"Unable to cast exception status to int {}. Defaulting to status=500.".format(
exception_status
)
)
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > allowed_fails or _should_retry is False:
# get the current cooldown list for that minute
verbose_router_logger.debug(f"adding {deployment} to cooldown models")
# update value
self.cooldown_cache.add_deployment_to_cooldown(
model_id=deployment,
original_exception=original_exception,
exception_status=exception_status,
cooldown_time=cooldown_time,
)
# Trigger cooldown handler
asyncio.create_task(
router_cooldown_handler(
litellm_router_instance=self,
deployment_id=deployment,
exception_status=exception_status,
cooldown_time=cooldown_time,
)
)
else:
self.failed_calls.set_cache(
key=deployment, value=updated_fails, ttl=cooldown_time
)
async def _async_get_cooldown_deployments(self) -> List[str]:
"""
Async implementation of '_get_cooldown_deployments'
"""
model_ids = self.get_model_ids()
cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids
)
cached_value_deployment_ids = []
if (
cooldown_models is not None
and isinstance(cooldown_models, list)
and len(cooldown_models) > 0
and isinstance(cooldown_models[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cached_value_deployment_ids
async def _async_get_cooldown_deployments_with_debug_info(self) -> List[tuple]:
"""
Async implementation of '_get_cooldown_deployments'
"""
model_ids = self.get_model_ids()
cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids
)
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def _get_cooldown_deployments(self) -> List[str]:
"""
Get the list of models being cooled down for this minute
"""
# get the current cooldown list for that minute
# ----------------------
# Return cooldown models
# ----------------------
model_ids = self.get_model_ids()
cooldown_models = self.cooldown_cache.get_active_cooldowns(model_ids=model_ids)
cached_value_deployment_ids = []
if (
cooldown_models is not None
and isinstance(cooldown_models, list)
and len(cooldown_models) > 0
and isinstance(cooldown_models[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
return cached_value_deployment_ids
def _get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
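
The methods deleted above are relocated rather than removed: they become module-level functions in litellm/router_utils/cooldown_handlers.py that take the router explicitly, which is why every call site below changes from self._get_cooldown_deployments() to _get_cooldown_deployments(litellm_router_instance=self). Their signatures, as implied by the import block and call sites in this diff (bodies elided; the concrete 5-second default is an assumption based on the "change default cooldown time" bullet):

```python
# litellm/router_utils/cooldown_handlers.py -- signatures inferred from this diff
from typing import Any, List, Optional, Tuple, Union

DEFAULT_COOLDOWN_TIME_SECONDS = 5  # assumed value, replacing the hardcoded 60


def _set_cooldown_deployments(
    litellm_router_instance,
    original_exception: Any,
    exception_status: Union[str, int],
    deployment: Optional[str] = None,
    time_to_cooldown: Optional[float] = None,
) -> None: ...


def _get_cooldown_deployments(litellm_router_instance) -> List[str]: ...


async def _async_get_cooldown_deployments(litellm_router_instance) -> List[str]: ...


async def _async_get_cooldown_deployments_with_debug_info(
    litellm_router_instance,
) -> List[Tuple]: ...
```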
@@ -3913,7 +3816,7 @@ class Router:
except:
pass
unhealthy_deployments = self._get_cooldown_deployments()
unhealthy_deployments = _get_cooldown_deployments(litellm_router_instance=self)
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
@@ -3930,11 +3833,13 @@
model=model,
)
if type(_all_deployments) == dict:
return []
return [], _all_deployments
except:
pass
unhealthy_deployments = await self._async_get_cooldown_deployments()
unhealthy_deployments = await _async_get_cooldown_deployments(
litellm_router_instance=self
)
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
@@ -3992,7 +3897,8 @@
target=logging_obj.failure_handler,
args=(e, traceback.format_exc()),
).start() # log response
self._set_cooldown_deployments(
_set_cooldown_deployments(
litellm_router_instance=self,
exception_status=e.status_code,
original_exception=e,
deployment=deployment["model_info"]["id"],
@@ -5241,7 +5147,9 @@
# filter out the deployments currently cooling down
deployments_to_remove = []
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = await self._async_get_cooldown_deployments()
cooldown_deployments = await _async_get_cooldown_deployments(
litellm_router_instance=self
)
verbose_router_logger.debug(
f"async cooldown deployments: {cooldown_deployments}"
)
@@ -5283,7 +5191,7 @@
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids
)
_cooldown_list = self._get_cooldown_deployments()
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
@@ -5398,7 +5306,7 @@
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids
)
_cooldown_list = self._get_cooldown_deployments()
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
@@ -5456,7 +5364,7 @@
# filter out the deployments currently cooling down
deployments_to_remove = []
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = self._get_cooldown_deployments()
cooldown_deployments = _get_cooldown_deployments(litellm_router_instance=self)
verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}")
# Find deployments in model_list whose model_id is cooling down
for deployment in healthy_deployments:
@@ -5479,7 +5387,7 @@
if len(healthy_deployments) == 0:
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
_cooldown_list = self._get_cooldown_deployments()
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
@@ -5588,7 +5496,7 @@
)
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
_cooldown_list = self._get_cooldown_deployments()
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
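
Because the cooldown helpers are now free functions, anything holding a Router instance (tests included) can inspect cooldown state directly. A hypothetical usage sketch with placeholder model and key:

```python
from litellm import Router
from litellm.router_utils.cooldown_handlers import _get_cooldown_deployments

# Hypothetical single-deployment router; model name and api_key are placeholders.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
        }
    ]
)

# Empty until a deployment fails enough of its current-minute traffic.
print(_get_cooldown_deployments(litellm_router_instance=router))
```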