fix(router.py): fix cooldown check

Krrish Dholakia 2024-08-28 16:38:05 -07:00
parent 25d8cb69a7
commit f0fb8bdf45
3 changed files with 24 additions and 31 deletions

litellm/proxy/proxy_config.yaml

@@ -1,19 +1,9 @@
 model_list:
   - model_name: fake-openai-endpoint
     litellm_params:
-      model: gpt-3.5-turbo
-      # model: sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614
-      # sagemaker_base_url: https://exampleopenaiendpoint-production.up.railway.app/invocations/
-      # api_base: https://exampleopenaiendpoint-production.up.railway.app
-      input_cost_per_token: 10
-      output_cost_per_token: 10
-
-litellm_settings:
-  max_internal_user_budget: 0.00001
-  internal_user_budget_duration: "3s" # reset every 3 seconds
-
-general_settings:
-  proxy_budget_rescheduler_min_time: 1
-  proxy_budget_rescheduler_max_time: 2
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
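
The new config swaps the real `gpt-3.5-turbo` deployment for a mock OpenAI-compatible endpoint. A minimal sketch of the equivalent programmatic setup, with the model name, key, and base URL taken from the config above:

```python
from litellm import Router

# Mirror the YAML config above: one fake OpenAI-compatible deployment.
router = Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
            },
        }
    ]
)
```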

litellm/router.py

@@ -92,6 +92,7 @@ from litellm.types.router import (
     RouterErrors,
     RouterGeneralSettings,
     RouterRateLimitError,
+    RouterRateLimitErrorBasic,
     updateDeployment,
     updateLiteLLMParams,
 )
@@ -4459,16 +4460,8 @@
         """
         if _rate_limit_error is True:  # allow generic fallback logic to take place
-            model_ids = self.get_model_ids(model_name=model)
-            cooldown_time = self.cooldown_cache.get_min_cooldown(
-                model_ids=model_ids
-            )
-            cooldown_list = self._get_cooldown_deployments()
-            raise RouterRateLimitError(
+            raise RouterRateLimitErrorBasic(
                 model=model,
-                cooldown_time=cooldown_time,
-                enable_pre_call_checks=True,
-                cooldown_list=cooldown_list,
             )
         elif _context_window_error is True:
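
The helper now raises the lightweight `RouterRateLimitErrorBasic` instead of assembling cooldown details inline. A hedged sketch of how a caller might tell the two error types apart (illustrative only, not part of the commit; `router` is the instance from the sketch above):

```python
from litellm.types.router import RouterRateLimitError, RouterRateLimitErrorBasic

messages = [{"role": "user", "content": "Hey, how's it going?"}]

try:
    response = router.completion(model="fake-openai-endpoint", messages=messages)
except RouterRateLimitErrorBasic as e:
    # Lightweight variant raised inside helper functions: carries only the model name.
    print(f"No deployments available for {e.model}; generic fallback logic can take over.")
except RouterRateLimitError as e:
    # Full variant, still raised elsewhere with cooldown details attached.
    print(f"{e.model} is cooling down for {e.cooldown_time}s; retry later.")
```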
@@ -4579,14 +4572,10 @@
         litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
         if len(healthy_deployments) == 0:
-            model_ids = self.get_model_ids(model_name=model)
-            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
-            _cooldown_list = self._get_cooldown_deployments()
-            raise RouterRateLimitError(
-                model=model,
-                cooldown_time=_cooldown_time,
-                enable_pre_call_checks=self.enable_pre_call_checks,
-                cooldown_list=_cooldown_list,
+            raise ValueError(
+                "{}. You passed in model={}. There is no 'model_name' with this string ".format(
+                    RouterErrors.no_deployments_available.value, model
+                )
             )
         if litellm.model_alias_map and model in litellm.model_alias_map:

litellm/types/router.py

@@ -551,6 +551,20 @@ class RouterGeneralSettings(BaseModel):
     )  # if passed a model not llm_router model list, pass through the request to litellm.acompletion/embedding
 
 
+class RouterRateLimitErrorBasic(ValueError):
+    """
+    Raise a basic error inside helper functions.
+    """
+
+    def __init__(
+        self,
+        model: str,
+    ):
+        self.model = model
+        _message = f"{RouterErrors.no_deployments_available.value}."
+        super().__init__(_message)
+
+
 class RouterRateLimitError(ValueError):
     def __init__(
         self,
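
A quick usage sketch of the new exception type; the asserted values follow directly from the `__init__` above:

```python
from litellm.types.router import RouterErrors, RouterRateLimitErrorBasic

try:
    raise RouterRateLimitErrorBasic(model="fake-openai-endpoint")
except RouterRateLimitErrorBasic as e:
    assert e.model == "fake-openai-endpoint"
    # The message is RouterErrors.no_deployments_available.value plus a trailing period.
    assert str(e) == f"{RouterErrors.no_deployments_available.value}."
```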