fix get healthy deployments

This commit is contained in:
Ishaan Jaff 2024-05-11 19:46:35 -07:00
parent 04ac352407
commit 61a3e5d5a9

View file

@ -1565,7 +1565,7 @@ class Router:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
_, _healthy_deployments = self._common_checks_available_deployment( _healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"), model=kwargs.get("model"),
) )
_timeout = self._time_to_sleep_before_retry( _timeout = self._time_to_sleep_before_retry(
@ -1796,7 +1796,7 @@ class Router:
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
_, _healthy_deployments = self._common_checks_available_deployment( _healthy_deployments = self._get_healthy_deployments(
model=kwargs.get("model"), model=kwargs.get("model"),
) )
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
@ -2008,12 +2008,18 @@ class Router:
return cooldown_models return cooldown_models
def _get_healthy_deployments(self, model: str): def _get_healthy_deployments(self, model: str):
_, _all_deployments = self._common_checks_available_deployment( _all_deployments: list = []
model=model, try:
) _, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = self._get_cooldown_deployments() unhealthy_deployments = self._get_cooldown_deployments()
healthy_deployments = [] healthy_deployments: list = []
for deployment in _all_deployments: for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments: if deployment["model_info"]["id"] in unhealthy_deployments:
continue continue
@ -2023,12 +2029,18 @@ class Router:
return healthy_deployments return healthy_deployments
async def _async_get_healthy_deployments(self, model: str): async def _async_get_healthy_deployments(self, model: str):
_, _all_deployments = self._common_checks_available_deployment( _all_deployments: list = []
model=model, try:
) _, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = await self._async_get_cooldown_deployments() unhealthy_deployments = await self._async_get_cooldown_deployments()
healthy_deployments = [] healthy_deployments: list = []
for deployment in _all_deployments: for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments: if deployment["model_info"]["id"] in unhealthy_deployments:
continue continue