fix(router.py): add support for cooldowns with redis

Krrish Dholakia 2023-11-22 19:54:15 -08:00
parent cb41b14cc2
commit 5d5ca9f7ef
3 changed files with 161 additions and 121 deletions

router.py

@@ -83,14 +83,14 @@ class Router:
         if cache_responses:
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
             self.cache_responses = cache_responses
-        self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
+        self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
         ## USAGE TRACKING ##
-        if type(litellm.success_callback) == list:
+        if isinstance(litellm.success_callback, list):
             litellm.success_callback.append(self.deployment_callback)
         else:
             litellm.success_callback = [self.deployment_callback]
-        if type(litellm.failure_callback) == list:
+        if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
         else:
             litellm.failure_callback = [self.deployment_callback_on_failure]
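The first hunk corrects the cache constructor call (the config dict must be unpacked into keyword arguments) and switches the callback type checks to isinstance. A minimal sketch of why the unpacking matters, using a Cache stand-in that mimics a keyword-argument constructor like litellm.Cache (the class and values below are illustrative, not litellm code):

    # Illustrative only: why `**cache_config` matters.
    class Cache:
        def __init__(self, type="local", host=None, port=None, password=None):
            self.type, self.host, self.port, self.password = type, host, port, password

    cache_config = {"type": "redis", "host": "localhost", "port": 6379, "password": ""}

    cache = Cache(**cache_config)   # unpacks the dict into keyword arguments -> type == "redis"
    print(cache.type, cache.host)   # redis localhost

    # Cache(cache_config) would instead bind the whole dict to the first positional
    # parameter (`type`), which is the bug this hunk fixes.
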
@@ -169,14 +169,12 @@ class Router:
         current_time = time.time()
         iter = 0
         deployments_to_remove = []
+        cooldown_deployments = self._get_cooldown_deployments()
         ### FIND UNHEALTHY DEPLOYMENTS
         for deployment in healthy_deployments:
             deployment_name = deployment["litellm_params"]["model"]
-            if deployment_name in self.cooldown_deployments:
-                if current_time >= self.cooldown_deployments[deployment_name] + 60:
-                    self.cooldown_deployments.pop(deployment_name)
-                else:
-                    deployments_to_remove.append(deployment)
+            if deployment_name in cooldown_deployments:
+                deployments_to_remove.append(deployment)
             iter += 1
         ### FILTER OUT UNHEALTHY DEPLOYMENTS
         for deployment in deployments_to_remove:
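With cooldown state moved into the cache, the deployment filter no longer expires entries itself; it only checks membership in the per-minute cooldown list. A self-contained sketch of that filtering step, with made-up deployment names:

    # Illustrative only: filtering deployments against the cooldown list.
    cooldown_deployments = ["azure/gpt-35-turbo"]  # e.g. what _get_cooldown_deployments() might return

    healthy_deployments = [
        {"litellm_params": {"model": "azure/gpt-35-turbo"}},
        {"litellm_params": {"model": "gpt-3.5-turbo"}},
    ]

    deployments_to_remove = []
    for deployment in healthy_deployments:
        deployment_name = deployment["litellm_params"]["model"]
        if deployment_name in cooldown_deployments:   # cooled down this minute -> skip it
            deployments_to_remove.append(deployment)

    for deployment in deployments_to_remove:
        healthy_deployments.remove(deployment)

    print([d["litellm_params"]["model"] for d in healthy_deployments])  # ['gpt-3.5-turbo']
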
@@ -245,36 +243,31 @@
             raise e
     def function_with_retries(self, *args, **kwargs):
+        try:
+            import tenacity
+        except Exception as e:
+            raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
+        retry_info = {"attempts": 0, "final_result": None}
-        # we'll backoff exponentially with each retry
-        backoff_factor = 1
-        original_exception = kwargs.pop("original_exception")
-        original_function = kwargs.pop("original_function")
-        for current_attempt in range(self.num_retries):
-            self.num_retries -= 1 # decrement the number of retries
-            try:
-                # if the function call is successful, no exception will be raised and we'll break out of the loop
-                response = original_function(*args, **kwargs)
-                return response
+        def after_callback(retry_state):
+            retry_info["attempts"] = retry_state.attempt_number
+            retry_info["final_result"] = retry_state.outcome.result()
-            except openai.RateLimitError as e:
-                # on RateLimitError we'll wait for an exponential time before trying again
-                time.sleep(backoff_factor)
+        if 'model' not in kwargs or 'messages' not in kwargs:
+            raise ValueError("'model' and 'messages' must be included as keyword arguments")
+        try:
+            original_exception = kwargs.pop("original_exception")
+            original_function = kwargs.pop("original_function")
+            if isinstance(original_exception, openai.RateLimitError):
+                retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10),
+                                            stop=tenacity.stop_after_attempt(self.num_retries),
+                                            reraise=True,
+                                            after=after_callback)
+            elif isinstance(original_exception, openai.APIError):
+                retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(self.num_retries),
+                                            reraise=True,
+                                            after=after_callback)
+            return retryer(original_function, *args, **kwargs)
+        except Exception as e:
+            raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")
-                # increase backoff factor for next run
-                backoff_factor *= 2
-            except openai.APIError as e:
-                # on APIError we immediately retry without any wait, change this if necessary
-                pass
-            except Exception as e:
-                # for any other exception types, don't retry
-                raise e
     ### COMPLETION + EMBEDDING FUNCTIONS
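The retry logic now delegates backoff and attempt counting to tenacity instead of the hand-rolled loop. A standalone sketch of the same tenacity pattern; flaky_call and the stop/wait settings below are made up for illustration:

    # Illustrative only: tenacity.Retrying with exponential backoff and an `after` hook.
    import tenacity

    attempts = {"count": 0}

    def flaky_call():
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise RuntimeError("transient failure")
        return "ok"

    def after_callback(retry_state):
        # called after every attempt; attempt_number tracks how many tries have run
        print(f"attempt {retry_state.attempt_number} finished")

    retryer = tenacity.Retrying(
        wait=tenacity.wait_exponential(multiplier=1, max=10),  # exponential backoff, capped at 10s
        stop=tenacity.stop_after_attempt(5),                   # give up after 5 attempts
        reraise=True,                                          # re-raise the last exception if all attempts fail
        after=after_callback,
    )

    print(retryer(flaky_call))  # prints "ok" on the third attempt
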
@@ -422,7 +415,48 @@ class Router:
             custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
-            self.cooldown_deployments[model_name] = time.time() # put deployment in cooldown mode
+            self._set_cooldown_deployments(model_name)
+
+    def _set_cooldown_deployments(self,
+                                  deployment: str):
+        """
+        Add a model to the list of models being cooled down for that minute
+        """
+        current_minute = datetime.now().strftime("%H-%M")
+        # get the current cooldown list for that minute
+        cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
+        cached_value = self.cache.get_cache(cache_key=cooldown_key)
+        # update value
+        try:
+            if deployment in cached_value:
+                pass
+            else:
+                cached_value = cached_value + [deployment]
+                # save updated value
+                self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+        except:
+            cached_value = [deployment]
+            # save updated value
+            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+
+    def _get_cooldown_deployments(self):
+        """
+        Get the list of models being cooled down for this minute
+        """
+        current_minute = datetime.now().strftime("%H-%M")
+        # get the current cooldown list for that minute
+        cooldown_key = f"{current_minute}:cooldown_models"
+        # ----------------------
+        # Return cooldown models
+        # ----------------------
+        cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
+        return cooldown_models
+
     def get_usage_based_available_deployment(self,
                                              model: str,
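
The new helpers group cooled-down models under a per-minute cache key with a 60-second TTL, so entries expire on their own instead of being popped manually. A self-contained sketch of that bookkeeping; DummyCache below is a stand-in for the Redis-backed litellm.Cache and is not part of litellm:

    # Illustrative only: per-minute cooldown key with a TTL.
    import time
    from datetime import datetime

    class DummyCache:
        """Minimal stand-in for a TTL cache: stores (expiry, value) pairs."""
        def __init__(self):
            self._store = {}

        def add_cache(self, result, cache_key, ttl):
            self._store[cache_key] = (time.time() + ttl, result)

        def get_cache(self, cache_key):
            entry = self._store.get(cache_key)
            if entry is None or entry[0] < time.time():
                return None
            return entry[1]

    cache = DummyCache()

    def set_cooldown(deployment: str):
        # group cooldown models by minute so one key covers the whole window
        cooldown_key = f"{datetime.now().strftime('%H-%M')}:cooldown_models"
        cached_value = cache.get_cache(cache_key=cooldown_key) or []
        if deployment not in cached_value:
            cache.add_cache(result=cached_value + [deployment], cache_key=cooldown_key, ttl=60)

    def get_cooldowns():
        cooldown_key = f"{datetime.now().strftime('%H-%M')}:cooldown_models"
        return cache.get_cache(cache_key=cooldown_key) or []

    set_cooldown("azure/gpt-35-turbo")
    print(get_cooldowns())  # ['azure/gpt-35-turbo'] until the minute key and TTL roll over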