Mirror of https://github.com/BerriAI/litellm.git
fix(router.py): add support for cooldowns with redis
commit 5d5ca9f7ef (parent cb41b14cc2)
3 changed files with 161 additions and 121 deletions
@@ -83,14 +83,14 @@ class Router:
         if cache_responses:
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
             self.cache_responses = cache_responses
-        self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
+        self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
         ## USAGE TRACKING ##
-        if type(litellm.success_callback) == list:
+        if isinstance(litellm.success_callback, list):
             litellm.success_callback.append(self.deployment_callback)
         else:
             litellm.success_callback = [self.deployment_callback]

-        if type(litellm.failure_callback) == list:
+        if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
         else:
             litellm.failure_callback = [self.deployment_callback_on_failure]
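For context, the cache_config consumed above is assembled from the Router's constructor arguments. A minimal usage sketch, assuming the constructor exposes redis_host, redis_port, redis_password and cache_responses keyword arguments; the model entries and Redis settings below are illustrative, not taken from this commit:

    # Illustrative only: placeholder model entries and Redis settings.
    from litellm import Router

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",                     # alias callers request
            "litellm_params": {"model": "azure/chatgpt-v-2"},  # actual deployment to call
        },
    ]

    router = Router(
        model_list=model_list,
        redis_host="localhost",   # shared Redis lets every Router instance see
        redis_port=6379,          # the same cooldown and usage state
        redis_password="",
        cache_responses=True,     # additionally cache completion responses in Redis
    )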
@@ -169,14 +169,12 @@ class Router:
         current_time = time.time()
-        iter = 0
         deployments_to_remove = []
+        cooldown_deployments = self._get_cooldown_deployments()
         ### FIND UNHEALTHY DEPLOYMENTS
         for deployment in healthy_deployments:
             deployment_name = deployment["litellm_params"]["model"]
-            if deployment_name in self.cooldown_deployments:
-                if current_time >= self.cooldown_deployments[deployment_name] + 60:
-                    self.cooldown_deployments.pop(deployment_name)
-                else:
-                    deployments_to_remove.append(deployment)
-            iter += 1
+            if deployment_name in cooldown_deployments:
+                deployments_to_remove.append(deployment)
         ### FILTER OUT UNHEALTHY DEPLOYMENTS
         for deployment in deployments_to_remove:
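The hunk above swaps per-process cooldown bookkeeping (a local dict of timestamps checked against current_time) for a single fetch of the shared, minute-scoped cooldown list, so every Router instance pointed at the same Redis skips the same deployments. The filtering itself reduces to a membership test; a standalone sketch of that pattern with made-up data:

    # Standalone illustration of the filter above; deployment names are made up.
    healthy_deployments = [
        {"litellm_params": {"model": "azure/chatgpt-v-2"}},
        {"litellm_params": {"model": "azure/chatgpt-functioncalling"}},
    ]
    cooldown_deployments = ["azure/chatgpt-v-2"]  # what _get_cooldown_deployments() might return

    available = [
        d for d in healthy_deployments
        if d["litellm_params"]["model"] not in cooldown_deployments
    ]
    print(available)  # only the deployment that is not cooling down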
@@ -245,36 +243,31 @@ class Router:
             raise e

     def function_with_retries(self, *args, **kwargs):
-        try:
-            import tenacity
-        except Exception as e:
-            raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
-
-        retry_info = {"attempts": 0, "final_result": None}
-
-        def after_callback(retry_state):
-            retry_info["attempts"] = retry_state.attempt_number
-            retry_info["final_result"] = retry_state.outcome.result()
-
-        if 'model' not in kwargs or 'messages' not in kwargs:
-            raise ValueError("'model' and 'messages' must be included as keyword arguments")
-
-        try:
-            original_exception = kwargs.pop("original_exception")
-            original_function = kwargs.pop("original_function")
-            if isinstance(original_exception, openai.RateLimitError):
-                retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10),
-                                            stop=tenacity.stop_after_attempt(self.num_retries),
-                                            reraise=True,
-                                            after=after_callback)
-            elif isinstance(original_exception, openai.APIError):
-                retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(self.num_retries),
-                                            reraise=True,
-                                            after=after_callback)
-
-            return retryer(original_function, *args, **kwargs)
-        except Exception as e:
-            raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")
+        # we'll backoff exponentially with each retry
+        backoff_factor = 1
+        original_exception = kwargs.pop("original_exception")
+        original_function = kwargs.pop("original_function")
+        for current_attempt in range(self.num_retries):
+            self.num_retries -= 1 # decrement the number of retries
+            try:
+                # if the function call is successful, no exception will be raised and we'll break out of the loop
+                response = original_function(*args, **kwargs)
+                return response
+
+            except openai.RateLimitError as e:
+                # on RateLimitError we'll wait for an exponential time before trying again
+                time.sleep(backoff_factor)
+
+                # increase backoff factor for next run
+                backoff_factor *= 2
+
+            except openai.APIError as e:
+                # on APIError we immediately retry without any wait, change this if necessary
+                pass
+
+            except Exception as e:
+                # for any other exception types, don't retry
+                raise e

     ### COMPLETION + EMBEDDING FUNCTIONS

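The rewritten function_with_retries drops the tenacity dependency in favor of a hand-rolled loop: wait 1s, 2s, 4s, ... after a RateLimitError, retry immediately after an APIError, and re-raise anything else. A self-contained sketch of that pattern, using stand-in exception classes instead of the openai ones:

    import time

    class RateLimitError(Exception):   # stand-in for openai.RateLimitError
        pass

    class APIError(Exception):         # stand-in for openai.APIError
        pass

    def call_with_retries(fn, *args, num_retries=3, **kwargs):
        """Retry fn with exponential backoff on rate limits (sketch of the pattern above)."""
        backoff_factor = 1
        for attempt in range(num_retries):
            try:
                return fn(*args, **kwargs)    # success: return immediately
            except RateLimitError:
                time.sleep(backoff_factor)    # wait 1s, 2s, 4s, ... before the next attempt
                backoff_factor *= 2
            except APIError:
                pass                          # transient API error: retry without waiting
            except Exception:
                raise                         # anything else is not retried
        return fn(*args, **kwargs)            # last attempt, let any exception propagate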
@@ -422,7 +415,48 @@ class Router:
             custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
-            self.cooldown_deployments[model_name] = time.time() # put deployment in cooldown mode
+
+            self._set_cooldown_deployments(model_name)
+
+    def _set_cooldown_deployments(self,
+                                  deployment: str):
+        """
+        Add a model to the list of models being cooled down for that minute
+        """
+
+        current_minute = datetime.now().strftime("%H-%M")
+        # get the current cooldown list for that minute
+        cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
+        cached_value = self.cache.get_cache(cache_key=cooldown_key)
+
+        # update value
+        try:
+            if deployment in cached_value:
+                pass
+            else:
+                cached_value = cached_value + [deployment]
+                # save updated value
+                self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+        except:
+            cached_value = [deployment]
+
+            # save updated value
+            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+
+    def _get_cooldown_deployments(self):
+        """
+        Get the list of models being cooled down for this minute
+        """
+        current_minute = datetime.now().strftime("%H-%M")
+        # get the current cooldown list for that minute
+        cooldown_key = f"{current_minute}:cooldown_models"
+
+        # ----------------------
+        # Return cooldown models
+        # ----------------------
+        cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
+
+        return cooldown_models

     def get_usage_based_available_deployment(self,
                                              model: str,
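The cooldown state itself lives in Redis under one key per wall-clock minute (for example "14-07:cooldown_models") with a 60 second TTL, so a routing pass costs a single cache read rather than one per deployment, and stale entries expire on their own. A small sketch of the key scheme, reusing the cache calls shown in the diff; the router object here is a hypothetical, already-initialized Router with a Redis-backed cache:

    from datetime import datetime

    # Key scheme used by the new helpers: one bucket of cooled-down models per minute.
    current_minute = datetime.now().strftime("%H-%M")      # e.g. "14-07"
    cooldown_key = f"{current_minute}:cooldown_models"     # e.g. "14-07:cooldown_models"

    # Assuming `router` is a Router wired to Redis as in the constructor sketch above:
    router._set_cooldown_deployments("azure/chatgpt-v-2")  # add to this minute's bucket
    print(router._get_cooldown_deployments())              # ["azure/chatgpt-v-2"] until the 60s TTL lapses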