fix(router.py): enable dynamic retry-after in exception string

Updates cooldown logic to cool down individual models (per-deployment cache entries instead of a per-minute grouped list).

 Closes https://github.com/BerriAI/litellm/issues/1339
Krrish Dholakia 2024-08-24 16:59:30 -07:00
parent b8ca6553b6
commit c795e9feeb
3 changed files with 271 additions and 91 deletions
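
In short: instead of grouping cooled-down deployments under a single per-minute cache key, the router now writes one cache entry per deployment with a TTL equal to that deployment's cooldown, so the reported retry-after can track the real minimum cooldown. A rough illustrative sketch of the idea (a plain dict standing in for the cache; the names below are made up for illustration and are not litellm APIs):

# Illustrative sketch only (not litellm code): per-deployment cooldown entries
# with their own TTL, instead of one shared per-minute list.
import time

cooldowns: dict[str, dict] = {}  # stand-in for the router's cache

def add_cooldown(model_id: str, cooldown_time: float, error: str) -> None:
    # one key per deployment, mirroring the "deployment:{model_id}:cooldown" scheme
    cooldowns[model_id] = {
        "expires_at": time.time() + cooldown_time,
        "cooldown_time": cooldown_time,
        "exception_received": error,
    }

def min_cooldown(model_ids: list[str], default: float = 60.0) -> float:
    # the smallest per-deployment cooldown drives the retry-after hint
    active = [cooldowns[m]["cooldown_time"] for m in model_ids if m in cooldowns]
    return min(active, default=default)

add_cooldown("deployment-a", cooldown_time=9.0, error="429 from provider")
print(min_cooldown(["deployment-a", "deployment-b"]))  # -> 9.0, not a flat 60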


@@ -58,6 +58,7 @@ from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,
 )
+from litellm.router_utils.cooldown_cache import CooldownCache
 from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
 from litellm.router_utils.fallback_event_handlers import (
     log_failure_fallback_event,
@@ -338,6 +339,9 @@ class Router:
         else:
             self.allowed_fails = litellm.allowed_fails
         self.cooldown_time = cooldown_time or 60
+        self.cooldown_cache = CooldownCache(
+            cache=self.cache, default_cooldown_time=self.cooldown_time
+        )
         self.disable_cooldowns = disable_cooldowns
         self.failed_calls = (
             InMemoryCache()
@@ -3243,52 +3247,14 @@ class Router:
             if updated_fails > allowed_fails or _should_retry is False:
                 # get the current cooldown list for that minute
-                cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
-                cached_value = self.cache.get_cache(
-                    key=cooldown_key
-                )  # [(deployment_id, {last_error_str, last_error_status_code})]
-                cached_value_deployment_ids = []
-                if (
-                    cached_value is not None
-                    and isinstance(cached_value, list)
-                    and len(cached_value) > 0
-                    and isinstance(cached_value[0], tuple)
-                ):
-                    cached_value_deployment_ids = [cv[0] for cv in cached_value]
-
                 verbose_router_logger.debug(f"adding {deployment} to cooldown models")
                 # update value
-                if cached_value is not None and len(cached_value_deployment_ids) > 0:
-                    if deployment in cached_value_deployment_ids:
-                        pass
-                    else:
-                        cached_value = cached_value + [
-                            (
-                                deployment,
-                                {
-                                    "Exception Received": str(original_exception),
-                                    "Status Code": str(exception_status),
-                                },
-                            )
-                        ]
-                        # save updated value
-                        self.cache.set_cache(
-                            value=cached_value, key=cooldown_key, ttl=cooldown_time
-                        )
-                else:
-                    cached_value = [
-                        (
-                            deployment,
-                            {
-                                "Exception Received": str(original_exception),
-                                "Status Code": str(exception_status),
-                            },
-                        )
-                    ]
-                    # save updated value
-                    self.cache.set_cache(
-                        value=cached_value, key=cooldown_key, ttl=cooldown_time
-                    )
+                self.cooldown_cache.add_deployment_to_cooldown(
+                    model_id=deployment,
+                    original_exception=original_exception,
+                    exception_status=exception_status,
+                    cooldown_time=cooldown_time,
+                )

                 # Trigger cooldown handler
                 asyncio.create_task(
@@ -3308,15 +3274,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )

         cached_value_deployment_ids = []
         if (
@@ -3334,15 +3295,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )

         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
@@ -3351,15 +3307,13 @@ class Router:
         """
         Get the list of models being cooled down for this minute
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
         # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
         # ----------------------
         # Return cooldown models
         # ----------------------
-        cooldown_models = self.cache.get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = self.cooldown_cache.get_active_cooldowns(model_ids=model_ids)

         cached_value_deployment_ids = []
         if (
@@ -3370,7 +3324,6 @@ class Router:
         ):
             cached_value_deployment_ids = [cv[0] for cv in cooldown_models]

-        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cached_value_deployment_ids

     def _get_healthy_deployments(self, model: str):
@@ -4061,15 +4014,20 @@ class Router:
                 rpm_usage += t
         return tpm_usage, rpm_usage

-    def get_model_ids(self) -> List[str]:
+    def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         """
+        if 'model_name' is none, returns all.
+
         Returns list of model id's.
         """
         ids = []
         for model in self.model_list:
             if "model_info" in model and "id" in model["model_info"]:
                 id = model["model_info"]["id"]
-                ids.append(id)
+                if model_name is not None and model["model_name"] == model_name:
+                    ids.append(id)
+                elif model_name is None:
+                    ids.append(id)
         return ids

     def get_model_names(self) -> List[str]:
@@ -4402,10 +4360,19 @@ class Router:
             - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
             """

-            if _rate_limit_error == True:  # allow generic fallback logic to take place
-                raise ValueError(
-                    f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
+            if _rate_limit_error is True:  # allow generic fallback logic to take place
+                model_ids = self.get_model_ids(model_name=model)
+                cooldown_time = self.cooldown_cache.get_min_cooldown(
+                    model_ids=model_ids
                 )
+                cooldown_list = self._get_cooldown_deployments()
+                raise RouterRateLimitError(
+                    model=model,
+                    cooldown_time=cooldown_time,
+                    enable_pre_call_checks=True,
+                    cooldown_list=cooldown_list,
+                )
             elif _context_window_error is True:
                 raise litellm.ContextWindowExceededError(
                     message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
@@ -4514,8 +4481,14 @@ class Router:
         litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")

         if len(healthy_deployments) == 0:
-            raise ValueError(
-                f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )

         if litellm.model_alias_map and model in litellm.model_alias_map:
@@ -4602,8 +4575,16 @@ class Router:
             if len(healthy_deployments) == 0:
                 if _allowed_model_region is None:
                     _allowed_model_region = "n/a"
-                raise ValueError(
-                    f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
+                model_ids = self.get_model_ids(model_name=model)
+                _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                    model_ids=model_ids
+                )
+                _cooldown_list = self._get_cooldown_deployments()
+                raise RouterRateLimitError(
+                    model=model,
+                    cooldown_time=_cooldown_time,
+                    enable_pre_call_checks=self.enable_pre_call_checks,
+                    cooldown_list=_cooldown_list,
                 )

             if (
@@ -4682,8 +4663,16 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
             )
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )

         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@@ -4755,7 +4744,8 @@ class Router:
             )

             if len(healthy_deployments) == 0:
-                _cooldown_time = self.cooldown_time  # [TODO] Make dynamic
+                model_ids = self.get_model_ids(model_name=model)
+                _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
                 _cooldown_list = self._get_cooldown_deployments()
                 raise RouterRateLimitError(
                     model=model,
@@ -4841,8 +4831,14 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
             )
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
            )

         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"


@@ -0,0 +1,152 @@
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""

import json
import time
from typing import List, Optional, Tuple, TypedDict

from litellm import verbose_logger
from litellm.caching import DualCache


class CooldownCacheValue(TypedDict):
    exception_received: str
    status_code: str
    timestamp: float
    cooldown_time: float


class CooldownCache:
    def __init__(self, cache: DualCache, default_cooldown_time: float):
        self.cache = cache
        self.default_cooldown_time = default_cooldown_time

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, dict]:
        try:
            current_time = time.time()
            cooldown_key = f"deployment:{model_id}:cooldown"

            # Store the cooldown information for the deployment separately
            cooldown_data = CooldownCacheValue(
                exception_received=str(original_exception),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        try:
            _cooldown_time = cooldown_time or self.default_cooldown_time
            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e
    async def async_add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        cooldown_key, cooldown_data = self._common_add_cooldown_logic(
            model_id=model_id,
            original_exception=original_exception,
            exception_status=exception_status,
            cooldown_time=cooldown_time or self.default_cooldown_time,
        )

        # Set the cache with a TTL equal to the cooldown time
        self.cache.set_cache(
            value=cooldown_data,
            key=cooldown_key,
            ttl=cooldown_time or self.default_cooldown_time,
        )
    async def async_get_active_cooldowns(
        self, model_ids: List[str]
    ) -> List[Tuple[str, dict]]:
        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = await self.cache.async_batch_get_cache(keys=keys)

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result:
                active_cooldowns.append((model_id, result))

        return active_cooldowns

    def get_active_cooldowns(self, model_ids: List[str]) -> List[Tuple[str, dict]]:
        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = self.cache.batch_get_cache(keys=keys)

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result:
                active_cooldowns.append((model_id, result))

        return active_cooldowns

    def get_min_cooldown(self, model_ids: List[str]) -> float:
        """Return min cooldown time required for a group of model id's."""
        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = self.cache.batch_get_cache(keys=keys)

        min_cooldown_time = self.default_cooldown_time
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)
                if cooldown_cache_value["cooldown_time"] < min_cooldown_time:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]

        return min_cooldown_time
# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, default_cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(model_id, original_exception, exception_status, cooldown_time)
# active_cooldowns = cooldown_cache.get_active_cooldowns(model_ids=your_model_ids)
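
A hedged usage sketch for this new helper, assuming litellm's DualCache can be constructed with defaults and using a made-up deployment id:

# Usage sketch (assumes DualCache() works with its defaults; deployment id is made up).
from litellm.caching import DualCache
from litellm.router_utils.cooldown_cache import CooldownCache

cooldown_cache = CooldownCache(cache=DualCache(), default_cooldown_time=60.0)

# Record a cooldown that honors the provider's own retry-after (9s), not a flat 60s.
cooldown_cache.add_deployment_to_cooldown(
    model_id="my-azure-deployment-1",
    original_exception=Exception("429: rate limited"),
    exception_status=429,
    cooldown_time=9.0,
)

print(cooldown_cache.get_active_cooldowns(model_ids=["my-azure-deployment-1"]))
print(cooldown_cache.get_min_cooldown(model_ids=["my-azure-deployment-1"]))  # -> 9.0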


@@ -2254,7 +2254,9 @@ def test_router_dynamic_cooldown_correct_retry_after_time(sync_mode):
         assert response_headers["retry-after"] == cooldown_time


-def test_router_dynamic_cooldown_message_retry_time():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_dynamic_cooldown_message_retry_time(sync_mode):
     """
     User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
     but Azure says to retry in at most 9s
@@ -2294,19 +2296,49 @@ def test_router_dynamic_cooldown_message_retry_time():
     ):
         for _ in range(2):
             try:
-                router.embedding(
-                    model="text-embedding-ada-002",
-                    input="Hello world!",
-                    client=openai_client,
-                )
+                if sync_mode:
+                    router.embedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
+                else:
+                    await router.aembedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
             except litellm.RateLimitError:
                 pass

-        try:
-            router.embedding(
-                model="text-embedding-ada-002",
-                input="Hello world!",
-                client=openai_client,
-            )
+        if sync_mode:
+            cooldown_deployments = router._get_cooldown_deployments()
+        else:
+            cooldown_deployments = await router._async_get_cooldown_deployments()
+        print(
+            "Cooldown deployments - {}\n{}".format(
+                cooldown_deployments, len(cooldown_deployments)
+            )
+        )
+
+        assert len(cooldown_deployments) > 0
+
+        exception_raised = False
+        try:
+            if sync_mode:
+                router.embedding(
+                    model="text-embedding-ada-002",
+                    input="Hello world!",
+                    client=openai_client,
+                )
+            else:
+                await router.aembedding(
+                    model="text-embedding-ada-002",
+                    input="Hello world!",
+                    client=openai_client,
+                )
         except litellm.types.router.RouterRateLimitError as e:
+            print(e)
+            exception_raised = True
             assert e.cooldown_time == cooldown_time
+
+        assert exception_raised
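
The reworked test now covers both the sync and async embedding paths. A hypothetical way to run just this test (assumes the router test module is collectable from the current directory):

# Hypothetical invocation; path/collection setup is assumed, adjust to your checkout.
import pytest

pytest.main(["-q", "-k", "test_router_dynamic_cooldown_message_retry_time"])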