fix(router.py): don't cooldown on apiconnectionerrors

Fixes an issue where a model would be placed in cooldown due to API connection errors.
Krrish Dholakia 2024-08-24 09:53:05 -07:00
parent 8782ee444d
commit 0b06a76cf9
3 changed files with 78 additions and 5 deletions


@@ -897,7 +897,9 @@ def completion(
     except Exception as e:
         if isinstance(e, VertexAIError):
             raise e
-        raise VertexAIError(status_code=500, message=str(e))
+        raise litellm.APIConnectionError(
+            message=str(e), llm_provider="vertex_ai", model=model
+        )
 
 
 async def async_completion(
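With this change, a generic Vertex AI failure reaches callers as litellm.APIConnectionError rather than a blanket VertexAIError with status 500. A minimal sketch of what calling code can now rely on, assuming the exception keeps the llm_provider/model attributes it is constructed with; the model name and prompt are illustrative:

    import litellm

    try:
        litellm.completion(
            model="vertex_ai/gemini-1.5-pro",
            messages=[{"role": "user", "content": "hi"}],
        )
    except litellm.APIConnectionError as e:
        # the re-raise above attaches provider/model metadata to the error
        print(e.llm_provider, e.model)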


@@ -3081,7 +3081,9 @@ class Router:
             key=rpm_key, value=request_count, local_only=True
         )  # don't change existing ttl
 
-    def _is_cooldown_required(self, exception_status: Union[str, int]):
+    def _is_cooldown_required(
+        self, exception_status: Union[str, int], exception_str: Optional[str] = None
+    ):
         """
         A function to determine if a cooldown is required based on the exception status.
@@ -3092,6 +3094,13 @@
             bool: True if a cooldown is required, False otherwise.
         """
         try:
+            ignored_strings = ["APIConnectionError"]
+            if (
+                exception_str is not None
+            ):  # don't cooldown on litellm api connection errors
+                for ignored_string in ignored_strings:
+                    if ignored_string in exception_str:
+                        return False
             if isinstance(exception_status, str):
                 exception_status = int(exception_status)
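The commented-out assertion in the new test (last file below) exercises this path directly; a minimal sketch of the intended behavior, assuming a default Router:

    from litellm import Router

    # a 500 would ordinarily qualify for cooldown, but an exception string
    # naming APIConnectionError now short-circuits the check to False
    assert (
        Router()._is_cooldown_required(
            exception_status=500,
            exception_str="litellm.APIConnectionError: connection reset",
        )
        is False
    )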
@@ -3177,7 +3186,12 @@
         if deployment is None:
             return
 
-        if self._is_cooldown_required(exception_status=exception_status) == False:
+        if (
+            self._is_cooldown_required(
+                exception_status=exception_status, exception_str=str(original_exception)
+            )
+            is False
+        ):
             return
 
         if deployment in self.provider_default_deployment_ids:
@@ -4418,7 +4432,7 @@ class Router:
         - List, if multiple models chosen
         """
         # check if aliases set on litellm model alias map
-        if specific_deployment == True:
+        if specific_deployment is True:
             # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
             for deployment in self.model_list:
                 deployment_model = deployment.get("litellm_params").get("model")
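The == True to is True change is a style/lint fix (flake8 E712); identity comparison also refuses truthy non-bool values. A tiny illustration:

    x = 1
    assert x == True      # equality: 1 == True, since bool is an int subclass
    assert x is not True  # identity: only the bool singleton True matches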
@@ -4492,6 +4506,7 @@
             raise ValueError(
                 f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
             )
+
         if litellm.model_alias_map and model in litellm.model_alias_map:
             model = litellm.model_alias_map[
                 model
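For context, litellm.model_alias_map maps a caller-facing alias to a concrete model name before routing; a small illustration with a hypothetical alias:

    import litellm

    litellm.model_alias_map = {"gpt-fast": "gpt-3.5-turbo"}  # hypothetical alias

    model = "gpt-fast"
    if litellm.model_alias_map and model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]
    assert model == "gpt-3.5-turbo"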
@@ -4722,7 +4737,15 @@
             # filter pre-call checks
             if self.enable_pre_call_checks and messages is not None:
                 healthy_deployments = self._pre_call_checks(
-                    model=model, healthy_deployments=healthy_deployments, messages=messages
+                    model=model,
+                    healthy_deployments=healthy_deployments,
+                    messages=messages,
+                    request_kwargs=request_kwargs,
                 )
 
+            if len(healthy_deployments) == 0:
+                raise ValueError(
+                    f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, cooldown_list={self._get_cooldown_deployments()}"
+                )
+
         if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
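The new guard fails fast when pre-call checks filter out every deployment, instead of handing an empty list to the routing strategy. The general shape of the pattern, with hypothetical names:

    from typing import Callable, List

    def filter_or_fail(
        candidates: List[dict], check: Callable[[dict], bool], cooldown_time: float
    ) -> List[dict]:
        # hypothetical helper mirroring the change above: filter first, then
        # raise a descriptive error rather than returning an empty list
        healthy = [c for c in candidates if check(c)]
        if len(healthy) == 0:
            raise ValueError(
                f"No deployments available, Try again in {cooldown_time} seconds"
            )
        return healthy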


@@ -2107,3 +2107,51 @@ def test_router_context_window_pre_call_check(model, base_model, llm_provider):
         pass
     except Exception as e:
         pytest.fail(f"Got unexpected exception on router! - {str(e)}")
+
+
+def test_router_cooldown_api_connection_error():
+    # try:
+    #     _ = litellm.completion(
+    #         model="vertex_ai/gemini-1.5-pro",
+    #         messages=[{"role": "admin", "content": "Fail on this!"}],
+    #     )
+    # except litellm.APIConnectionError as e:
+    #     assert (
+    #         Router()._is_cooldown_required(
+    #             exception_status=e.code, exception_str=str(e)
+    #         )
+    #         is False
+    #     )
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gemini-1.5-pro",
+                "litellm_params": {"model": "vertex_ai/gemini-1.5-pro"},
+            }
+        ]
+    )
+
+    try:
+        router.completion(
+            model="gemini-1.5-pro",
+            messages=[{"role": "admin", "content": "Fail on this!"}],
+        )
+    except litellm.APIConnectionError:
+        pass
+
+    try:
+        router.completion(
+            model="gemini-1.5-pro",
+            messages=[{"role": "admin", "content": "Fail on this!"}],
+        )
+    except litellm.APIConnectionError:
+        pass
+
+    try:
+        router.completion(
+            model="gemini-1.5-pro",
+            messages=[{"role": "admin", "content": "Fail on this!"}],
+        )
+    except litellm.APIConnectionError:
+        pass
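A natural follow-up assertion (a sketch; _get_cooldown_deployments is the same private helper interpolated into the router's error message above) would pin down the fix:

    # after three consecutive connection failures, nothing should be cooling down
    assert router._get_cooldown_deployments() == []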