mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00

fix(router.py): enable dynamic retry after in exception string

Updates the cooldown logic to cool down individual models. Closes https://github.com/BerriAI/litellm/issues/1339

parent 76834c6c59
commit 33972cc79c

3 changed files with 271 additions and 91 deletions
litellm/router.py

@@ -58,6 +58,7 @@ from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,
 )
+from litellm.router_utils.cooldown_cache import CooldownCache
 from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
 from litellm.router_utils.fallback_event_handlers import (
     log_failure_fallback_event,
@@ -338,6 +339,9 @@ class Router:
         else:
             self.allowed_fails = litellm.allowed_fails
         self.cooldown_time = cooldown_time or 60
+        self.cooldown_cache = CooldownCache(
+            cache=self.cache, default_cooldown_time=self.cooldown_time
+        )
         self.disable_cooldowns = disable_cooldowns
         self.failed_calls = (
             InMemoryCache()
@@ -3243,51 +3247,13 @@ class Router:
 
         if updated_fails > allowed_fails or _should_retry is False:
-            # get the current cooldown list for that minute
-            cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
-            cached_value = self.cache.get_cache(
-                key=cooldown_key
-            )  # [(deployment_id, {last_error_str, last_error_status_code})]
-
-            cached_value_deployment_ids = []
-            if (
-                cached_value is not None
-                and isinstance(cached_value, list)
-                and len(cached_value) > 0
-                and isinstance(cached_value[0], tuple)
-            ):
-                cached_value_deployment_ids = [cv[0] for cv in cached_value]
             verbose_router_logger.debug(f"adding {deployment} to cooldown models")
-            # update value
-            if cached_value is not None and len(cached_value_deployment_ids) > 0:
-                if deployment in cached_value_deployment_ids:
-                    pass
-                else:
-                    cached_value = cached_value + [
-                        (
-                            deployment,
-                            {
-                                "Exception Received": str(original_exception),
-                                "Status Code": str(exception_status),
-                            },
-                        )
-                    ]
-                    # save updated value
-                    self.cache.set_cache(
-                        value=cached_value, key=cooldown_key, ttl=cooldown_time
-                    )
-            else:
-                cached_value = [
-                    (
-                        deployment,
-                        {
-                            "Exception Received": str(original_exception),
-                            "Status Code": str(exception_status),
-                        },
-                    )
-                ]
-                # save updated value
-                self.cache.set_cache(
-                    value=cached_value, key=cooldown_key, ttl=cooldown_time
-                )
+            self.cooldown_cache.add_deployment_to_cooldown(
+                model_id=deployment,
+                original_exception=original_exception,
+                exception_status=exception_status,
+                cooldown_time=cooldown_time,
+            )
 
             # Trigger cooldown handler
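For context: the hunk above replaces the old per-minute bookkeeping, where every cooled-down deployment in the current minute was appended to one shared list, with one cache entry per deployment. A minimal sketch of the two key schemes (key formats from the diff; the surrounding variables are assumed):

    # Old: one shared key per minute. The whole list expires together, so a
    # deployment's cooldown could be cut short or stretched by its neighbors.
    cooldown_key = f"{current_minute}:cooldown_models"  # e.g. "14-32:cooldown_models"

    # New: one key per deployment, whose TTL is that deployment's own
    # (possibly provider-supplied) cooldown_time.
    cooldown_key = f"deployment:{model_id}:cooldown"  # e.g. "deployment:some-id:cooldown"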
@@ -3308,15 +3274,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )
 
         cached_value_deployment_ids = []
         if (
@@ -3334,15 +3295,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )
 
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
@@ -3351,15 +3307,13 @@ class Router:
         """
         Get the list of models being cooled down for this minute
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = self.cache.get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = self.cooldown_cache.get_active_cooldowns(model_ids=model_ids)
 
         cached_value_deployment_ids = []
         if (
@@ -3370,7 +3324,6 @@ class Router:
         ):
             cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
 
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cached_value_deployment_ids
 
     def _get_healthy_deployments(self, model: str):
@@ -4061,14 +4014,19 @@ class Router:
             rpm_usage += t
         return tpm_usage, rpm_usage
 
-    def get_model_ids(self) -> List[str]:
+    def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         """
+        if 'model_name' is none, returns all.
+
         Returns list of model id's.
         """
         ids = []
         for model in self.model_list:
             if "model_info" in model and "id" in model["model_info"]:
                 id = model["model_info"]["id"]
-                ids.append(id)
+                if model_name is not None and model["model_name"] == model_name:
+                    ids.append(id)
+                elif model_name is None:
+                    ids.append(id)
         return ids
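The new optional model_name parameter lets the cooldown lookups below scope to a single model group. A sketch of the expected behavior (the deployments and ids are made up for illustration):

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "1"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "2"},
            },
        ]
    )
    assert router.get_model_ids() == ["1", "2"]  # no filter: every deployment id
    assert router.get_model_ids(model_name="gpt-4") == ["2"]  # only that model group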
@@ -4402,10 +4360,19 @@ class Router:
         - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
         """
 
-        if _rate_limit_error == True:  # allow generic fallback logic to take place
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
-            )
+        if _rate_limit_error is True:  # allow generic fallback logic to take place
+            model_ids = self.get_model_ids(model_name=model)
+            cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=cooldown_time,
+                enable_pre_call_checks=True,
+                cooldown_list=cooldown_list,
+            )
 
         elif _context_window_error is True:
             raise litellm.ContextWindowExceededError(
                 message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
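Callers that previously string-matched the ValueError message can now catch a typed error and read the cooldown directly. A sketch using only what this diff shows (the import path follows the test at the bottom of this commit; the call site is illustrative):

    from litellm.types.router import RouterRateLimitError

    try:
        deployment = router.get_available_deployment(model="gpt-3.5-turbo")
    except RouterRateLimitError as e:
        # cooldown_time is now the smallest active per-deployment cooldown,
        # e.g. a provider-supplied retry-after, instead of the static default.
        print(f"retry in {e.cooldown_time}s")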
@@ -4514,8 +4481,14 @@ class Router:
         litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
 
         if len(healthy_deployments) == 0:
-            raise ValueError(
-                f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
-            )
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
+            )
 
         if litellm.model_alias_map and model in litellm.model_alias_map:
@@ -4602,8 +4575,16 @@ class Router:
         if len(healthy_deployments) == 0:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
-            )
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
+            )
 
         if (
@@ -4682,8 +4663,16 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
             )
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
-            )
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
+            )
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@@ -4755,7 +4744,8 @@ class Router:
             )
 
         if len(healthy_deployments) == 0:
-            _cooldown_time = self.cooldown_time  # [TODO] Make dynamic
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
             _cooldown_list = self._get_cooldown_deployments()
             raise RouterRateLimitError(
                 model=model,
@@ -4841,8 +4831,14 @@ class Router:
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, No deployment available"
         )
-        raise ValueError(
-            f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
-        )
+        model_ids = self.get_model_ids(model_name=model)
+        _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+        _cooldown_list = self._get_cooldown_deployments()
+        raise RouterRateLimitError(
+            model=model,
+            cooldown_time=_cooldown_time,
+            enable_pre_call_checks=self.enable_pre_call_checks,
+            cooldown_list=_cooldown_list,
+        )
     verbose_router_logger.info(
         f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
litellm/router_utils/cooldown_cache.py (new file, 152 lines)
@@ -0,0 +1,152 @@
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""

import json
import time
from typing import List, Optional, Tuple, TypedDict

from litellm import verbose_logger
from litellm.caching import DualCache


class CooldownCacheValue(TypedDict):
    exception_received: str
    status_code: str
    timestamp: float
    cooldown_time: float


class CooldownCache:
    def __init__(self, cache: DualCache, default_cooldown_time: float):
        self.cache = cache
        self.default_cooldown_time = default_cooldown_time

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, dict]:
        try:
            current_time = time.time()
            cooldown_key = f"deployment:{model_id}:cooldown"

            # Store the cooldown information for the deployment separately
            cooldown_data = CooldownCacheValue(
                exception_received=str(original_exception),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        try:
            _cooldown_time = cooldown_time or self.default_cooldown_time
            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    async def async_add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        _cooldown_time = cooldown_time or self.default_cooldown_time
        cooldown_key, cooldown_data = self._common_add_cooldown_logic(
            model_id=model_id,
            original_exception=original_exception,
            exception_status=exception_status,
            cooldown_time=_cooldown_time,
        )

        # Set the cache with a TTL equal to the cooldown time
        self.cache.set_cache(
            value=cooldown_data,
            key=cooldown_key,
            ttl=_cooldown_time,
        )

    async def async_get_active_cooldowns(
        self, model_ids: List[str]
    ) -> List[Tuple[str, dict]]:
        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = await self.cache.async_batch_get_cache(keys=keys)

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result:
                active_cooldowns.append((model_id, result))

        return active_cooldowns

    def get_active_cooldowns(self, model_ids: List[str]) -> List[Tuple[str, dict]]:
        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = self.cache.batch_get_cache(keys=keys)

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result:
                active_cooldowns.append((model_id, result))

        return active_cooldowns

    def get_min_cooldown(self, model_ids: List[str]) -> float:
        """Return min cooldown time required for a group of model id's."""

        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = self.cache.batch_get_cache(keys=keys)

        min_cooldown_time = self.default_cooldown_time
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)
                if cooldown_cache_value["cooldown_time"] < min_cooldown_time:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]

        return min_cooldown_time


# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, default_cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
# active_cooldowns = cooldown_cache.get_active_cooldowns(model_ids=model_ids)
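For context, an end-to-end sketch of the new cache (this assumes DualCache() can be constructed with no arguments as a purely in-memory cache; the ids and values are hypothetical):

    from litellm.caching import DualCache
    from litellm.router_utils.cooldown_cache import CooldownCache

    cooldown_cache = CooldownCache(cache=DualCache(), default_cooldown_time=60.0)

    # The provider answered 429 with "retry after 9s": cool this deployment
    # down for 9 seconds instead of the router-wide 60 second default.
    cooldown_cache.add_deployment_to_cooldown(
        model_id="deployment-1",
        original_exception=Exception("429: retry after 9s"),
        exception_status=429,
        cooldown_time=9.0,
    )

    cooldown_cache.get_active_cooldowns(model_ids=["deployment-1", "deployment-2"])
    # -> [("deployment-1", {"exception_received": "...", "status_code": "429",
    #                       "timestamp": ..., "cooldown_time": 9.0})]

    cooldown_cache.get_min_cooldown(model_ids=["deployment-1", "deployment-2"])
    # -> 9.0, the value RouterRateLimitError now surfaces as the retry time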
@@ -2254,7 +2254,9 @@ def test_router_dynamic_cooldown_correct_retry_after_time(sync_mode):
     assert response_headers["retry-after"] == cooldown_time
 
 
-def test_router_dynamic_cooldown_message_retry_time():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_dynamic_cooldown_message_retry_time(sync_mode):
     """
     User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
     but Azure says to retry in at most 9s
@@ -2294,19 +2296,49 @@ def test_router_dynamic_cooldown_message_retry_time():
     ):
         for _ in range(2):
             try:
-                router.embedding(
-                    model="text-embedding-ada-002",
-                    input="Hello world!",
-                    client=openai_client,
-                )
+                if sync_mode:
+                    router.embedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
+                else:
+                    await router.aembedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
             except litellm.RateLimitError:
                 pass
 
+        if sync_mode:
+            cooldown_deployments = router._get_cooldown_deployments()
+        else:
+            cooldown_deployments = await router._async_get_cooldown_deployments()
+        print(
+            "Cooldown deployments - {}\n{}".format(
+                cooldown_deployments, len(cooldown_deployments)
+            )
+        )
+
+        assert len(cooldown_deployments) > 0
+        exception_raised = False
         try:
-            router.embedding(
-                model="text-embedding-ada-002",
-                input="Hello world!",
-                client=openai_client,
-            )
+            if sync_mode:
+                router.embedding(
+                    model="text-embedding-ada-002",
+                    input="Hello world!",
+                    client=openai_client,
+                )
+            else:
+                await router.aembedding(
+                    model="text-embedding-ada-002",
+                    input="Hello world!",
+                    client=openai_client,
+                )
         except litellm.types.router.RouterRateLimitError as e:
             print(e)
+            exception_raised = True
+            assert e.cooldown_time == cooldown_time
+
+        assert exception_raised