diff --git a/litellm/router.py b/litellm/router.py
index 48cd4427df..a263d4cca4 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -58,6 +58,7 @@ from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,
 )
+from litellm.router_utils.cooldown_cache import CooldownCache
 from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
 from litellm.router_utils.fallback_event_handlers import (
     log_failure_fallback_event,
@@ -338,6 +339,9 @@ class Router:
         else:
             self.allowed_fails = litellm.allowed_fails
         self.cooldown_time = cooldown_time or 60
+        self.cooldown_cache = CooldownCache(
+            cache=self.cache, default_cooldown_time=self.cooldown_time
+        )
         self.disable_cooldowns = disable_cooldowns
         self.failed_calls = (
             InMemoryCache()
@@ -3243,52 +3247,14 @@ class Router:
         if updated_fails > allowed_fails or _should_retry is False:
             # get the current cooldown list for that minute
-            cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
-            cached_value = self.cache.get_cache(
-                key=cooldown_key
-            )  # [(deployment_id, {last_error_str, last_error_status_code})]
-
-            cached_value_deployment_ids = []
-            if (
-                cached_value is not None
-                and isinstance(cached_value, list)
-                and len(cached_value) > 0
-                and isinstance(cached_value[0], tuple)
-            ):
-                cached_value_deployment_ids = [cv[0] for cv in cached_value]
 
             verbose_router_logger.debug(f"adding {deployment} to cooldown models")
             # update value
-            if cached_value is not None and len(cached_value_deployment_ids) > 0:
-                if deployment in cached_value_deployment_ids:
-                    pass
-                else:
-                    cached_value = cached_value + [
-                        (
-                            deployment,
-                            {
-                                "Exception Received": str(original_exception),
-                                "Status Code": str(exception_status),
-                            },
-                        )
-                    ]
-                    # save updated value
-                    self.cache.set_cache(
-                        value=cached_value, key=cooldown_key, ttl=cooldown_time
-                    )
-            else:
-                cached_value = [
-                    (
-                        deployment,
-                        {
-                            "Exception Received": str(original_exception),
-                            "Status Code": str(exception_status),
-                        },
-                    )
-                ]
-                # save updated value
-                self.cache.set_cache(
-                    value=cached_value, key=cooldown_key, ttl=cooldown_time
-                )
+            self.cooldown_cache.add_deployment_to_cooldown(
+                model_id=deployment,
+                original_exception=original_exception,
+                exception_status=exception_status,
+                cooldown_time=cooldown_time,
+            )
 
             # Trigger cooldown handler
             asyncio.create_task(
@@ -3308,15 +3274,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )
 
         cached_value_deployment_ids = []
         if (
@@ -3334,15 +3295,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )
 
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
@@ -3351,15 +3307,13 @@ class Router:
         """
         Get the list of models being cooled down for this minute
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
         # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
 
         # ----------------------
         # Return cooldown models
         # ----------------------
-        cooldown_models = self.cache.get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = self.cooldown_cache.get_active_cooldowns(model_ids=model_ids)
 
         cached_value_deployment_ids = []
         if (
@@ -3370,7 +3324,6 @@ class Router:
         ):
             cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
 
-        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cached_value_deployment_ids
 
     def _get_healthy_deployments(self, model: str):
@@ -4061,15 +4014,20 @@ class Router:
                 rpm_usage += t
         return tpm_usage, rpm_usage
 
-    def get_model_ids(self) -> List[str]:
+    def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         """
+        If 'model_name' is None, returns all.
+
         Returns list of model id's.
         """
         ids = []
         for model in self.model_list:
             if "model_info" in model and "id" in model["model_info"]:
                 id = model["model_info"]["id"]
-                ids.append(id)
+                if model_name is not None and model["model_name"] == model_name:
+                    ids.append(id)
+                elif model_name is None:
+                    ids.append(id)
         return ids
 
     def get_model_names(self) -> List[str]:
@@ -4402,10 +4360,19 @@ class Router:
             - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
             """
 
-            if _rate_limit_error == True:  # allow generic fallback logic to take place
-                raise ValueError(
-                    f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
+            if _rate_limit_error is True:  # allow generic fallback logic to take place
+                model_ids = self.get_model_ids(model_name=model)
+                cooldown_time = self.cooldown_cache.get_min_cooldown(
+                    model_ids=model_ids
                 )
+                cooldown_list = self._get_cooldown_deployments()
+                raise RouterRateLimitError(
+                    model=model,
+                    cooldown_time=cooldown_time,
+                    enable_pre_call_checks=True,
+                    cooldown_list=cooldown_list,
+                )
+
             elif _context_window_error is True:
                 raise litellm.ContextWindowExceededError(
                     message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
@@ -4514,8 +4481,14 @@ class Router:
         litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
 
         if len(healthy_deployments) == 0:
-            raise ValueError(
-                f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )
 
         if litellm.model_alias_map and model in litellm.model_alias_map:
@@ -4602,8 +4575,16 @@ class Router:
         if len(healthy_deployments) == 0:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )
 
         if (
@@ -4682,8 +4663,16 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
             )
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@@ -4755,7 +4744,8 @@ class Router:
         )
 
         if len(healthy_deployments) == 0:
-            _cooldown_time = self.cooldown_time  # [TODO] Make dynamic
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
             _cooldown_list = self._get_cooldown_deployments()
             raise RouterRateLimitError(
                 model=model,
@@ -4841,8 +4831,14 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
             )
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
diff --git a/litellm/router_utils/cooldown_cache.py b/litellm/router_utils/cooldown_cache.py
new file mode 100644
index 0000000000..50a23e5306
--- /dev/null
+++ b/litellm/router_utils/cooldown_cache.py
@@ -0,0 +1,152 @@
+"""
+Wrapper around router cache. Meant to handle model cooldown logic
+"""
+
+import json
+import time
+from typing import List, Optional, Tuple, TypedDict
+
+from litellm import verbose_logger
+from litellm.caching import DualCache
+
+
+class CooldownCacheValue(TypedDict):
+    exception_received: str
+    status_code: str
+    timestamp: float
+    cooldown_time: float
+
+
+class CooldownCache:
+    def __init__(self, cache: DualCache, default_cooldown_time: float):
+        self.cache = cache
+        self.default_cooldown_time = default_cooldown_time
+
+    def _common_add_cooldown_logic(
+        self, model_id: str, original_exception, exception_status, cooldown_time: float
+    ) -> Tuple[str, dict]:
+        try:
+            current_time = time.time()
+            cooldown_key = f"deployment:{model_id}:cooldown"
+
+            # Store the cooldown information for the deployment separately
+            cooldown_data = CooldownCacheValue(
+                exception_received=str(original_exception),
+                status_code=str(exception_status),
+                timestamp=current_time,
+                cooldown_time=cooldown_time,
+            )
+
+            return cooldown_key, cooldown_data
+        except Exception as e:
+            verbose_logger.error(
+                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            raise e
+
+    def add_deployment_to_cooldown(
+        self,
+        model_id: str,
+        original_exception: Exception,
+        exception_status: int,
+        cooldown_time: Optional[float],
+    ):
+        try:
+            _cooldown_time = cooldown_time or self.default_cooldown_time
+            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
+                model_id=model_id,
+                original_exception=original_exception,
+                exception_status=exception_status,
+                cooldown_time=_cooldown_time,
+            )
+
+            # Set the cache with a TTL equal to the cooldown time
+            self.cache.set_cache(
+                value=cooldown_data,
+                key=cooldown_key,
+                ttl=_cooldown_time,
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            raise e
+
+    async def async_add_deployment_to_cooldown(
+        self,
+        model_id: str,
+        original_exception: Exception,
+        exception_status: int,
+        cooldown_time: Optional[float],
+    ):
+        cooldown_key, cooldown_data = self._common_add_cooldown_logic(
+            model_id=model_id,
+            original_exception=original_exception,
+            exception_status=exception_status,
+            cooldown_time=cooldown_time or self.default_cooldown_time,
+        )
+
+        # Set the cache with a TTL equal to the cooldown time
+        self.cache.set_cache(
+            value=cooldown_data,
+            key=cooldown_key,
+            ttl=cooldown_time or self.default_cooldown_time,
+        )
+
+    async def async_get_active_cooldowns(
+        self, model_ids: List[str]
+    ) -> List[Tuple[str, dict]]:
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+
+        # Retrieve the values for the keys using mget
+        results = await self.cache.async_batch_get_cache(keys=keys)
+
+        active_cooldowns = []
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result:
+                active_cooldowns.append((model_id, result))
+
+        return active_cooldowns
+
+    def get_active_cooldowns(self, model_ids: List[str]) -> List[Tuple[str, dict]]:
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+
+        # Retrieve the values for the keys using mget
+        results = self.cache.batch_get_cache(keys=keys)
+
+        active_cooldowns = []
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result:
+                active_cooldowns.append((model_id, result))
+
+        return active_cooldowns
+
+    def get_min_cooldown(self, model_ids: List[str]) -> float:
+        """Return min cooldown time required for a group of model ids."""
+
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+
+        # Retrieve the values for the keys using mget
+        results = self.cache.batch_get_cache(keys=keys)
+
+        min_cooldown_time = self.default_cooldown_time
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result and isinstance(result, dict):
+                cooldown_cache_value = CooldownCacheValue(**result)
+                if cooldown_cache_value["cooldown_time"] < min_cooldown_time:
+                    min_cooldown_time = cooldown_cache_value["cooldown_time"]
+
+        return min_cooldown_time
+
+
+# Usage example:
+# cooldown_cache = CooldownCache(cache=your_cache_instance, default_cooldown_time=your_cooldown_time)
+# cooldown_cache.add_deployment_to_cooldown(model_id, original_exception, exception_status, cooldown_time)
+# active_cooldowns = cooldown_cache.get_active_cooldowns(model_ids=[model_id])
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 6e58a04f9e..d3d5cad74d 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -2254,7 +2254,9 @@ def test_router_dynamic_cooldown_correct_retry_after_time(sync_mode):
     assert response_headers["retry-after"] == cooldown_time
 
 
-def test_router_dynamic_cooldown_message_retry_time():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_dynamic_cooldown_message_retry_time(sync_mode):
     """
     User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
     but Azure says to retry in at most 9s
@@ -2294,19 +2296,49 @@ def test_router_dynamic_cooldown_message_retry_time(sync_mode):
     ):
         for _ in range(2):
             try:
+                if sync_mode:
+                    router.embedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
+                else:
+                    await router.aembedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
+            except litellm.RateLimitError:
+                pass
+
+        if sync_mode:
+            cooldown_deployments = router._get_cooldown_deployments()
+        else:
+            cooldown_deployments = await router._async_get_cooldown_deployments()
+        print(
+            "Cooldown deployments - {}\n{}".format(
+                cooldown_deployments, len(cooldown_deployments)
+            )
+        )
+
+        assert len(cooldown_deployments) > 0
+        exception_raised = False
+        try:
+            if sync_mode:
                 router.embedding(
                     model="text-embedding-ada-002",
                     input="Hello world!",
                     client=openai_client,
                 )
-            except litellm.RateLimitError:
-                pass
-
-        try:
-            router.embedding(
-                model="text-embedding-ada-002",
-                input="Hello world!",
-                client=openai_client,
-            )
+            else:
+                await router.aembedding(
+                    model="text-embedding-ada-002",
+                    input="Hello world!",
+                    client=openai_client,
+                )
         except litellm.types.router.RouterRateLimitError as e:
+            print(e)
+            exception_raised = True
             assert e.cooldown_time == cooldown_time
+
+        assert exception_raised
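
For reviewers, a minimal sketch (not part of the patch) of how the new CooldownCache is intended to be driven. The deployment ids, the exception, and the 9-second cooldown below are made-up values, and DualCache() is only an in-memory stand-in for the router's own cache instance.

from litellm.caching import DualCache
from litellm.router_utils.cooldown_cache import CooldownCache

cooldown_cache = CooldownCache(cache=DualCache(), default_cooldown_time=60.0)

# Put one deployment on cooldown; the cache entry expires after `cooldown_time` seconds.
cooldown_cache.add_deployment_to_cooldown(
    model_id="deployment-1",  # hypothetical deployment id
    original_exception=Exception("429: rate limited"),  # made-up exception
    exception_status=429,
    cooldown_time=9.0,
)

# One batched cache read returns (model_id, CooldownCacheValue) pairs for cooled-down deployments.
active = cooldown_cache.get_active_cooldowns(model_ids=["deployment-1", "deployment-2"])

# Smallest cooldown across the group; the router uses this to populate
# RouterRateLimitError.cooldown_time instead of the static self.cooldown_time.
min_cooldown = cooldown_cache.get_min_cooldown(model_ids=["deployment-1", "deployment-2"])
print(active, min_cooldown)

Because each deployment now gets its own deployment:{model_id}:cooldown key with a TTL, cooldowns expire independently and carry per-deployment retry times, rather than being grouped into a shared per-minute list in the cache.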