fix(router.py): enable dynamic retry after in exception string

Updates cooldown logic to cooldown individual models

 Closes https://github.com/BerriAI/litellm/issues/1339
This commit is contained in:
Krrish Dholakia 2024-08-24 16:59:30 -07:00
parent 76834c6c59
commit 33972cc79c
3 changed files with 271 additions and 91 deletions

View file

@ -58,6 +58,7 @@ from litellm.router_utils.client_initalization_utils import (
set_client,
should_initialize_sync_client,
)
from litellm.router_utils.cooldown_cache import CooldownCache
from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
from litellm.router_utils.fallback_event_handlers import (
log_failure_fallback_event,
@ -338,6 +339,9 @@ class Router:
else:
self.allowed_fails = litellm.allowed_fails
self.cooldown_time = cooldown_time or 60
self.cooldown_cache = CooldownCache(
cache=self.cache, default_cooldown_time=self.cooldown_time
)
self.disable_cooldowns = disable_cooldowns
self.failed_calls = (
InMemoryCache()
@ -3243,52 +3247,14 @@ class Router:
if updated_fails > allowed_fails or _should_retry is False:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(
key=cooldown_key
) # [(deployment_id, {last_error_str, last_error_status_code})]
cached_value_deployment_ids = []
if (
cached_value is not None
and isinstance(cached_value, list)
and len(cached_value) > 0
and isinstance(cached_value[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cached_value]
verbose_router_logger.debug(f"adding {deployment} to cooldown models")
# update value
if cached_value is not None and len(cached_value_deployment_ids) > 0:
if deployment in cached_value_deployment_ids:
pass
else:
cached_value = cached_value + [
(
deployment,
{
"Exception Received": str(original_exception),
"Status Code": str(exception_status),
},
)
]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
)
else:
cached_value = [
(
deployment,
{
"Exception Received": str(original_exception),
"Status Code": str(exception_status),
},
)
]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
)
self.cooldown_cache.add_deployment_to_cooldown(
model_id=deployment,
original_exception=original_exception,
exception_status=exception_status,
cooldown_time=cooldown_time,
)
# Trigger cooldown handler
asyncio.create_task(
@ -3308,15 +3274,10 @@ class Router:
"""
Async implementation of '_get_cooldown_deployments'
"""
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
model_ids = self.get_model_ids()
cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids
)
cached_value_deployment_ids = []
if (
@ -3334,15 +3295,10 @@ class Router:
"""
Async implementation of '_get_cooldown_deployments'
"""
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
model_ids = self.get_model_ids()
cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids
)
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
@ -3351,15 +3307,13 @@ class Router:
"""
Get the list of models being cooled down for this minute
"""
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = self.cache.get_cache(key=cooldown_key) or []
model_ids = self.get_model_ids()
cooldown_models = self.cooldown_cache.get_active_cooldowns(model_ids=model_ids)
cached_value_deployment_ids = []
if (
@ -3370,7 +3324,6 @@ class Router:
):
cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cached_value_deployment_ids
def _get_healthy_deployments(self, model: str):
@ -4061,15 +4014,20 @@ class Router:
rpm_usage += t
return tpm_usage, rpm_usage
def get_model_ids(self) -> List[str]:
def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
"""
if 'model_name' is none, returns all.
Returns list of model id's.
"""
ids = []
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
id = model["model_info"]["id"]
ids.append(id)
if model_name is not None and model["model_name"] == model_name:
ids.append(id)
elif model_name is None:
ids.append(id)
return ids
def get_model_names(self) -> List[str]:
@ -4402,10 +4360,19 @@ class Router:
- First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
"""
if _rate_limit_error == True: # allow generic fallback logic to take place
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
if _rate_limit_error is True: # allow generic fallback logic to take place
model_ids = self.get_model_ids(model_name=model)
cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids
)
cooldown_list = self._get_cooldown_deployments()
raise RouterRateLimitError(
model=model,
cooldown_time=cooldown_time,
enable_pre_call_checks=True,
cooldown_list=cooldown_list,
)
elif _context_window_error is True:
raise litellm.ContextWindowExceededError(
message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
@ -4514,8 +4481,14 @@ class Router:
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
if len(healthy_deployments) == 0:
raise ValueError(
f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
_cooldown_list = self._get_cooldown_deployments()
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
enable_pre_call_checks=self.enable_pre_call_checks,
cooldown_list=_cooldown_list,
)
if litellm.model_alias_map and model in litellm.model_alias_map:
@ -4602,8 +4575,16 @@ class Router:
if len(healthy_deployments) == 0:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids
)
_cooldown_list = self._get_cooldown_deployments()
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
enable_pre_call_checks=self.enable_pre_call_checks,
cooldown_list=_cooldown_list,
)
if (
@ -4682,8 +4663,16 @@ class Router:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
)
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids
)
_cooldown_list = self._get_cooldown_deployments()
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
enable_pre_call_checks=self.enable_pre_call_checks,
cooldown_list=_cooldown_list,
)
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@ -4755,7 +4744,8 @@ class Router:
)
if len(healthy_deployments) == 0:
_cooldown_time = self.cooldown_time # [TODO] Make dynamic
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
_cooldown_list = self._get_cooldown_deployments()
raise RouterRateLimitError(
model=model,
@ -4841,8 +4831,14 @@ class Router:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
)
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
_cooldown_list = self._get_cooldown_deployments()
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
enable_pre_call_checks=self.enable_pre_call_checks,
cooldown_list=_cooldown_list,
)
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"