Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
[Fix-Router] Don't cooldown when only 1 deployment exists (#5673)

* fix get model list
* fix test custom callback router
* fix embedding fallback test
* fix router retry policy on AuthErrors
* fix router test
* add test for single deployments no cooldown test prod
* add test test_single_deployment_no_cooldowns_test_prod_mock_completion_calls
This commit is contained in:
parent: 40c52f9263
commit: e7c22f63e7

4 changed files with 128 additions and 17 deletions
Config changes (model_list):

```diff
@@ -6,11 +6,11 @@ model_list:
       vertex_project: "adroit-crow-413218"
       vertex_location: "us-central1"
       vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
-  - model_name: fake-openai-endpoint
+  - model_name: fake-azure-endpoint
     litellm_params:
-      model: openai/fake
+      model: openai/429
       api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      api_base: https://exampleopenaiendpoint-production.up.railway.app

 general_settings:
   master_key: sk-1234
```
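For readers who use the Python SDK rather than the proxy config, an entry like `fake-azure-endpoint` above maps one-to-one onto a `Router` `model_list` dictionary. A minimal sketch, assuming an SDK-style setup (the model name, key, and api_base are copied from the config above; nothing else about the proxy wiring is implied):

```python
from litellm import Router

# Sketch: the in-code equivalent of the "fake-azure-endpoint" config entry above.
# Values are taken from the YAML; this is illustrative, not the proxy's exact setup.
router = Router(
    model_list=[
        {
            "model_name": "fake-azure-endpoint",
            "litellm_params": {
                "model": "openai/429",
                "api_key": "fake-key",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app",
            },
        }
    ]
)
```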
Router changes:

```diff
@@ -1130,7 +1130,7 @@ class Router:
             make_request = False

             while curr_time < end_time:
-                _healthy_deployments = await self._async_get_healthy_deployments(
+                _healthy_deployments, _ = await self._async_get_healthy_deployments(
                     model=model
                 )
                 make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
```
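The signature change here ripples through every call site: `_async_get_healthy_deployments` (and its sync counterpart further down) now returns a `(healthy_deployments, all_deployments)` pair, so callers that only need the healthy list unpack and discard the second element. A minimal sketch of the new calling convention, assuming an illustrative single-deployment router:

```python
import asyncio
from litellm import Router

async def show_healthy_deployments() -> None:
    # Illustrative router; the deployment parameters are placeholders.
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ]
    )
    # Post-PR convention: the helper returns (healthy_deployments, all_deployments).
    healthy_deployments, _ = await router._async_get_healthy_deployments(
        model="gpt-3.5-turbo"
    )
    print("healthy deployments:", healthy_deployments)

asyncio.run(show_healthy_deployments())
```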
```diff
@@ -3060,14 +3060,17 @@ class Router:
             Retry Logic

             """
-            _healthy_deployments = await self._async_get_healthy_deployments(
-                model=kwargs.get("model") or "",
-            )
+            _healthy_deployments, _all_deployments = (
+                await self._async_get_healthy_deployments(
+                    model=kwargs.get("model") or "",
+                )
+            )

             # raises an exception if this error should not be retries
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
+                all_deployments=_all_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
                 content_policy_fallbacks=content_policy_fallbacks,
```
```diff
@@ -3114,7 +3117,7 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    _healthy_deployments = await self._async_get_healthy_deployments(
+                    _healthy_deployments, _ = await self._async_get_healthy_deployments(
                         model=kwargs.get("model"),
                     )
                     _timeout = self._time_to_sleep_before_retry(
```
```diff
@@ -3135,6 +3138,7 @@ class Router:
         self,
         error: Exception,
         healthy_deployments: Optional[List] = None,
+        all_deployments: Optional[List] = None,
         context_window_fallbacks: Optional[List] = None,
         content_policy_fallbacks: Optional[List] = None,
         regular_fallbacks: Optional[List] = None,
```
```diff
@@ -3150,6 +3154,9 @@ class Router:
         _num_healthy_deployments = 0
         if healthy_deployments is not None and isinstance(healthy_deployments, list):
             _num_healthy_deployments = len(healthy_deployments)
+        _num_all_deployments = 0
+        if all_deployments is not None and isinstance(all_deployments, list):
+            _num_all_deployments = len(all_deployments)

         ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
         if (
```
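The new counting block follows the same defensive pattern already used for `healthy_deployments`: `None` or a non-list value counts as zero. A standalone sketch of that guard (the helper name `count_deployments` is made up for illustration and is not part of the Router API):

```python
from typing import List, Optional

def count_deployments(deployments: Optional[List]) -> int:
    """Illustrative helper mirroring the guard above: None / non-list counts as 0."""
    if deployments is not None and isinstance(deployments, list):
        return len(deployments)
    return 0

# Quick checks of the edge cases the guard protects against.
assert count_deployments(None) == 0
assert count_deployments([]) == 0
assert count_deployments([{"model_name": "gpt-3.5-turbo"}]) == 1
```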
```diff
@@ -3180,7 +3187,9 @@ class Router:
         - if other deployments available -> retry
         - else -> raise error
         """
-        if _num_healthy_deployments <= 0:  # if no healthy deployments
+        if (
+            _num_all_deployments <= 1
+        ):  # if there is only 1 deployment for this model group then don't retry
             raise error  # then raise error

         # Do not retry if there are no healthy deployments
```
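This hunk is the behavioral heart of the fix: the old guard re-raised only when there were zero healthy deployments, while the new guard re-raises whenever the model group has at most one deployment in total, since retrying (and cooling that lone deployment down) cannot route the request anywhere else. A self-contained sketch of the decision, using the hypothetical name `decide_retry` (a distillation of the idea, not the library's `should_retry_this_error`):

```python
from typing import List, Optional

def decide_retry(error: Exception, all_deployments: Optional[List]) -> bool:
    """Hypothetical distillation of the new guard: with <= 1 deployment in the
    model group there is nothing to rotate to, so re-raise instead of retrying."""
    num_all_deployments = (
        len(all_deployments) if isinstance(all_deployments, list) else 0
    )
    if num_all_deployments <= 1:  # only one deployment -> retrying cannot help
        raise error
    return True

# A single-deployment group re-raises the original rate-limit error.
try:
    decide_retry(Exception("429: rate limited"), [{"model_name": "gpt-3.5-turbo"}])
except Exception as e:
    print("raised as expected:", e)

# With two deployments the caller is allowed to retry on another one.
print(decide_retry(Exception("429"), [{"model_name": "a"}, {"model_name": "b"}]))
```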
```diff
@@ -3390,7 +3399,7 @@ class Router:
             current_attempt = None
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            _healthy_deployments = self._get_healthy_deployments(
+            _healthy_deployments, _all_deployments = self._get_healthy_deployments(
                 model=kwargs.get("model"),
             )

```
```diff
@@ -3398,6 +3407,7 @@ class Router:
             self.should_retry_this_error(
                 error=e,
                 healthy_deployments=_healthy_deployments,
+                all_deployments=_all_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
                 content_policy_fallbacks=content_policy_fallbacks,
```
```diff
@@ -3428,7 +3438,7 @@ class Router:
                 except Exception as e:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
-                    _healthy_deployments = self._get_healthy_deployments(
+                    _healthy_deployments, _ = self._get_healthy_deployments(
                         model=kwargs.get("model"),
                     )
                     remaining_retries = num_retries - current_attempt
```
```diff
@@ -3881,7 +3891,7 @@ class Router:
             else:
                 healthy_deployments.append(deployment)

-        return healthy_deployments
+        return healthy_deployments, _all_deployments

     async def _async_get_healthy_deployments(self, model: str):
         _all_deployments: list = []
```
```diff
@@ -3901,7 +3911,7 @@ class Router:
                 continue
             else:
                 healthy_deployments.append(deployment)
-        return healthy_deployments
+        return healthy_deployments, _all_deployments

     def routing_strategy_pre_call_checks(self, deployment: dict):
         """
```
```diff
@@ -4679,10 +4689,7 @@ class Router:
                 returned_models += self.model_list

                 return returned_models
-            for model in self.model_list:
-                returned_models.extend(
-                    self._get_all_deployments(model_name=model_name)
-                )
+            returned_models.extend(self._get_all_deployments(model_name=model_name))

             return returned_models
         return None
```
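This is the "fix get model list" item from the commit message: instead of looping over the full `self.model_list` and extending once per entry, the router now expands the requested model group exactly once through `_get_all_deployments(model_name=...)`. A hedged usage sketch (the deployment entries are placeholders; `_get_all_deployments` is the internal helper named in the diff):

```python
from litellm import Router

# Illustrative router with two model groups; parameters are placeholders.
router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "gpt-5", "litellm_params": {"model": "openai/gpt-5"}},
    ]
)

# Expanding one model group should return only that group's deployments.
deployments = router._get_all_deployments(model_name="gpt-3.5-turbo")
print(len(deployments))  # expected: 1 for this example list
```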
Custom callback router test changes:

```diff
@@ -533,6 +533,7 @@ async def test_async_chat_azure_with_fallbacks():
     try:
         customHandler_fallbacks = CompletionCustomHandler()
         litellm.callbacks = [customHandler_fallbacks]
+        litellm.set_verbose = True
         # with fallbacks
         model_list = [
             {
```
```diff
@@ -555,7 +556,13 @@ async def test_async_chat_azure_with_fallbacks():
                 "rpm": 1800,
             },
         ]
-        router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}])  # type: ignore
+        router = Router(
+            model_list=model_list,
+            fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
+            retry_policy=litellm.router.RetryPolicy(
+                AuthenticationErrorRetries=0,
+            ),
+        )  # type: ignore
         response = await router.acompletion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
```
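The test now fails fast on authentication errors by pinning `AuthenticationErrorRetries` to zero in a `RetryPolicy`, which is the "fix router retry policy on AuthErrors" item from the commit message. A minimal sketch of that configuration outside the test (the two model entries are illustrative placeholders):

```python
import litellm
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "gpt-3.5-turbo-16k", "litellm_params": {"model": "gpt-3.5-turbo-16k"}},
    ],
    fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
    retry_policy=litellm.router.RetryPolicy(
        AuthenticationErrorRetries=0,  # do not retry authentication errors
    ),
)
```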
Router cooldowns test changes:

```diff
@@ -150,3 +150,100 @@ def test_single_deployment_no_cooldowns(num_deployments):
             mock_client.assert_not_called()
         else:
             mock_client.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_single_deployment_no_cooldowns_test_prod():
+    """
+    Do not cooldown on single deployment.
+
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "gpt-5",
+                "litellm_params": {
+                    "model": "openai/gpt-5",
+                },
+            },
+            {
+                "model_name": "gpt-12",
+                "litellm_params": {
+                    "model": "openai/gpt-12",
+                },
+            },
+        ],
+        allowed_fails=0,
+        num_retries=0,
+    )
+
+    with patch.object(
+        router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock()
+    ) as mock_client:
+        try:
+            await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except litellm.RateLimitError:
+            pass
+
+        await asyncio.sleep(2)
+
+        mock_client.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_single_deployment_no_cooldowns_test_prod_mock_completion_calls():
+    """
+    Do not cooldown on single deployment.
+
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "gpt-5",
+                "litellm_params": {
+                    "model": "openai/gpt-5",
+                },
+            },
+            {
+                "model_name": "gpt-12",
+                "litellm_params": {
+                    "model": "openai/gpt-12",
+                },
+            },
+        ],
+    )
+
+    for _ in range(20):
+        try:
+            await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except litellm.RateLimitError:
+            pass
+
+    cooldown_list = await router._async_get_cooldown_deployments()
+    assert len(cooldown_list) == 0
+
+    healthy_deployments, _ = await router._async_get_healthy_deployments(
+        model="gpt-3.5-turbo"
+    )
+
+    print("healthy_deployments: ", healthy_deployments)
```
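The same scenario the new tests cover can be reproduced outside pytest; a hedged sketch that reuses the calls from the tests above (`acompletion` with `mock_response` and the private `_async_get_cooldown_deployments` helper are taken from them, and the deployment entry is a placeholder):

```python
import asyncio
import litellm
from litellm import Router

async def main() -> None:
    # Single-deployment router, as in the new tests; parameters are placeholders.
    router = Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
        ]
    )

    # Force repeated simulated rate-limit failures against the lone deployment.
    for _ in range(5):
        try:
            await router.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "ping"}],
                mock_response="litellm.RateLimitError",
            )
        except litellm.RateLimitError:
            pass

    # With only one deployment, the fix means nothing should be placed in cooldown.
    cooldown_list = await router._async_get_cooldown_deployments()
    print("deployments in cooldown:", len(cooldown_list))  # expected: 0

asyncio.run(main())
```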