diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py index 763bd9105..616ea9f7e 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/distribution/routers/inference.py @@ -602,7 +602,7 @@ class InferenceRouter(Inference): async def health(self) -> dict[str, HealthResponse]: health_statuses = {} - timeout = 0.5 + timeout = 1 # increasing the timeout to 1 second for health checks for provider_id, impl in self.routing_table.impls_by_provider_id.items(): try: # check if the provider has a health method diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index b703c07fc..d0a822f3c 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -313,10 +313,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): """ try: client = self._create_client() if self.client is None else self.client - client.models.list() # Ensure the client is initialized + _ = [m async for m in client.models.list()] # Ensure the client is initialized return HealthResponse(status=HealthStatus.OK) - except Exception as ex: - return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}") + except Exception as e: + return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") async def _get_model(self, model_id: str) -> Model: if not self.model_store: diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 1fc68a631..c6b62e3e7 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -650,15 +650,23 @@ async def test_health_status_success(vllm_inference_adapter): This test verifies that the health method returns a HealthResponse with status OK, only when the connection to the vLLM server is successful. """ - # Mock the client.models.list method to return successfully # Set vllm_inference_adapter.client to None to ensure _create_client is called vllm_inference_adapter.client = None with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client: # Create mock client and models mock_client = MagicMock() mock_models = MagicMock() + + # Create a mock async iterator that yields a model when iterated + async def mock_list(): + for model in [MagicMock()]: + yield model + + # Set up the models.list to return our mock async iterator + mock_models.list.return_value = mock_list() mock_client.models = mock_models mock_create_client.return_value = mock_client + # Call the health method health_response = await vllm_inference_adapter.health() # Verify the response @@ -677,14 +685,21 @@ async def test_health_status_failure(vllm_inference_adapter): and an appropriate error message when the connection to the vLLM server fails. """ vllm_inference_adapter.client = None - # Mock the client.models.list method to raise an exception with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client: # Create mock client and models mock_client = MagicMock() mock_models = MagicMock() - mock_models.list.side_effect = Exception("Connection failed") + + # Create a mock async iterator that raises an exception when iterated + async def mock_list(): + raise Exception("Connection failed") + yield # Unreachable code + + # Set up the models.list to return our mock async iterator + mock_models.list.return_value = mock_list() mock_client.models = mock_models mock_create_client.return_value = mock_client + # Call the health method health_response = await vllm_inference_adapter.health() # Verify the response