[Fix-Router] Don't cooldown when only 1 deployment exists (#5673)

* fix get model list

* fix test custom callback router

* fix embedding fallback test

* fix router retry policy on AuthErrors

* fix router test

* add test for single deployments no cooldown test prod

* add test test_single_deployment_no_cooldowns_test_prod_mock_completion_calls
Ishaan Jaff, 2024-09-12 19:14:58 -07:00, committed by GitHub
parent 40c52f9263 · commit e7c22f63e7
4 changed files with 128 additions and 17 deletions


@@ -6,11 +6,11 @@ model_list:
       vertex_project: "adroit-crow-413218"
       vertex_location: "us-central1"
       vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
-  - model_name: fake-openai-endpoint
+  - model_name: fake-azure-endpoint
     litellm_params:
-      model: openai/fake
+      model: openai/429
       api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      api_base: https://exampleopenaiendpoint-production.up.railway.app
 general_settings:
   master_key: sk-1234


@@ -1130,7 +1130,7 @@ class Router:
                 make_request = False
                 while curr_time < end_time:
-                    _healthy_deployments = await self._async_get_healthy_deployments(
+                    _healthy_deployments, _ = await self._async_get_healthy_deployments(
                         model=model
                     )
                     make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
@@ -3060,14 +3060,17 @@ class Router:
             Retry Logic
             """
-            _healthy_deployments = await self._async_get_healthy_deployments(
-                model=kwargs.get("model") or "",
-            )
+            _healthy_deployments, _all_deployments = (
+                await self._async_get_healthy_deployments(
+                    model=kwargs.get("model") or "",
+                )
+            )
 
             # raises an exception if this error should not be retries
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
+                all_deployments=_all_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
                 content_policy_fallbacks=content_policy_fallbacks,
@@ -3114,7 +3117,7 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    _healthy_deployments = await self._async_get_healthy_deployments(
+                    _healthy_deployments, _ = await self._async_get_healthy_deployments(
                         model=kwargs.get("model"),
                     )
                     _timeout = self._time_to_sleep_before_retry(
@@ -3135,6 +3138,7 @@ class Router:
         self,
         error: Exception,
         healthy_deployments: Optional[List] = None,
+        all_deployments: Optional[List] = None,
         context_window_fallbacks: Optional[List] = None,
         content_policy_fallbacks: Optional[List] = None,
         regular_fallbacks: Optional[List] = None,
@@ -3150,6 +3154,9 @@ class Router:
         _num_healthy_deployments = 0
         if healthy_deployments is not None and isinstance(healthy_deployments, list):
             _num_healthy_deployments = len(healthy_deployments)
+        _num_all_deployments = 0
+        if all_deployments is not None and isinstance(all_deployments, list):
+            _num_all_deployments = len(all_deployments)
 
         ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
         if (
@@ -3180,7 +3187,9 @@ class Router:
             - if other deployments available -> retry
             - else -> raise error
             """
-            if _num_healthy_deployments <= 0:  # if no healthy deployments
+            if (
+                _num_all_deployments <= 1
+            ):  # if there is only 1 deployment for this model group then don't retry
                 raise error  # then raise error
 
         # Do not retry if there are no healthy deployments
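Restated outside the diff, the new guard in `should_retry_this_error` amounts to the following sketch (names and structure are paraphrased from the hunk above; the real method also handles fallbacks and several error classes):

    from typing import List, Optional


    def should_retry_this_error_sketch(
        error: Exception,
        healthy_deployments: Optional[List] = None,
        all_deployments: Optional[List] = None,
    ) -> None:
        """Raise `error` when retrying cannot help; return normally otherwise."""
        num_healthy = len(healthy_deployments) if isinstance(healthy_deployments, list) else 0
        num_all = len(all_deployments) if isinstance(all_deployments, list) else 0

        # New check: a model group with a single deployment raises immediately,
        # since a retry could only land on the same deployment that just failed.
        if num_all <= 1:
            raise error

        # Pre-existing check: several deployments configured, but none healthy.
        if num_healthy <= 0:
            raise error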
@@ -3390,7 +3399,7 @@ class Router:
             current_attempt = None
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            _healthy_deployments = self._get_healthy_deployments(
+            _healthy_deployments, _all_deployments = self._get_healthy_deployments(
                 model=kwargs.get("model"),
             )
@@ -3398,6 +3407,7 @@ class Router:
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
+                all_deployments=_all_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
                 content_policy_fallbacks=content_policy_fallbacks,
@@ -3428,7 +3438,7 @@ class Router:
                 except Exception as e:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
-                    _healthy_deployments = self._get_healthy_deployments(
+                    _healthy_deployments, _ = self._get_healthy_deployments(
                         model=kwargs.get("model"),
                     )
                     remaining_retries = num_retries - current_attempt
@@ -3881,7 +3891,7 @@ class Router:
             else:
                 healthy_deployments.append(deployment)
 
-        return healthy_deployments
+        return healthy_deployments, _all_deployments
 
     async def _async_get_healthy_deployments(self, model: str):
         _all_deployments: list = []
@@ -3901,7 +3911,7 @@ class Router:
                 continue
             else:
                 healthy_deployments.append(deployment)
 
-        return healthy_deployments
+        return healthy_deployments, _all_deployments
 
     def routing_strategy_pre_call_checks(self, deployment: dict):
         """
@@ -4679,10 +4689,7 @@ class Router:
             returned_models += self.model_list
 
             return returned_models
 
-            for model in self.model_list:
-                returned_models.extend(self._get_all_deployments(model_name=model_name))
-                return returned_models
+            returned_models.extend(self._get_all_deployments(model_name=model_name))
+            return returned_models
         return None


@@ -533,6 +533,7 @@ async def test_async_chat_azure_with_fallbacks():
     try:
         customHandler_fallbacks = CompletionCustomHandler()
         litellm.callbacks = [customHandler_fallbacks]
+        litellm.set_verbose = True
         # with fallbacks
         model_list = [
             {
@@ -555,7 +556,13 @@ async def test_async_chat_azure_with_fallbacks():
                 "rpm": 1800,
             },
         ]
-        router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}])  # type: ignore
+        router = Router(
+            model_list=model_list,
+            fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
+            retry_policy=litellm.router.RetryPolicy(
+                AuthenticationErrorRetries=0,
+            ),
+        )  # type: ignore
         response = await router.acompletion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],


@@ -150,3 +150,100 @@ def test_single_deployment_no_cooldowns(num_deployments):
             mock_client.assert_not_called()
         else:
             mock_client.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_single_deployment_no_cooldowns_test_prod():
+    """
+    Do not cooldown on single deployment.
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "gpt-5",
+                "litellm_params": {
+                    "model": "openai/gpt-5",
+                },
+            },
+            {
+                "model_name": "gpt-12",
+                "litellm_params": {
+                    "model": "openai/gpt-12",
+                },
+            },
+        ],
+        allowed_fails=0,
+        num_retries=0,
+    )
+
+    with patch.object(
+        router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock()
+    ) as mock_client:
+        try:
+            await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except litellm.RateLimitError:
+            pass
+
+        await asyncio.sleep(2)
+
+        mock_client.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_single_deployment_no_cooldowns_test_prod_mock_completion_calls():
+    """
+    Do not cooldown on single deployment.
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "gpt-5",
+                "litellm_params": {
+                    "model": "openai/gpt-5",
+                },
+            },
+            {
+                "model_name": "gpt-12",
+                "litellm_params": {
+                    "model": "openai/gpt-12",
+                },
+            },
+        ],
+    )
+
+    for _ in range(20):
+        try:
+            await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except litellm.RateLimitError:
+            pass
+
+    cooldown_list = await router._async_get_cooldown_deployments()
+    assert len(cooldown_list) == 0
+
+    healthy_deployments, _ = await router._async_get_healthy_deployments(
+        model="gpt-3.5-turbo"
+    )
+
+    print("healthy_deployments: ", healthy_deployments)