From 33ecefd284171fdde6814e22796c12a26c1c0284 Mon Sep 17 00:00:00 2001
From: Sumit Jaiswal
Date: Sat, 7 Jun 2025 01:03:12 +0530
Subject: [PATCH] feat: add health status check for remote vLLM (#2303)

# What does this PR do?
Adds a health status check for the remote vLLM inference provider.

## Test Plan
This PR includes unit tests covering the added health check implementation.
---
 llama_stack/distribution/routers/inference.py |  2 +-
 .../providers/remote/inference/vllm/vllm.py   | 22 +++++-
 .../providers/inference/test_remote_vllm.py   | 70 ++++++++++++++++++-
 3 files changed, 91 insertions(+), 3 deletions(-)

diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py
index 763bd9105..2e111c20a 100644
--- a/llama_stack/distribution/routers/inference.py
+++ b/llama_stack/distribution/routers/inference.py
@@ -602,7 +602,7 @@ class InferenceRouter(Inference):
 
     async def health(self) -> dict[str, HealthResponse]:
         health_statuses = {}
-        timeout = 0.5
+        timeout = 1  # increase the timeout to 1 second for health checks
         for provider_id, impl in self.routing_table.impls_by_provider_id.items():
             try:
                 # check if the provider has a health method
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 9f38d9abf..d0a822f3c 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -56,7 +56,11 @@ from llama_stack.apis.inference.inference import (
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import (
+    HealthResponse,
+    HealthStatus,
+    ModelsProtocolPrivate,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
@@ -298,6 +302,22 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def health(self) -> HealthResponse:
+        """
+        Performs a health check by verifying connectivity to the remote vLLM server.
+        This method is used by the Provider API to verify that the service is
+        running correctly.
+
+        Returns:
+            HealthResponse: A dictionary containing the health status.
+        """
+        try:
+            client = self._create_client() if self.client is None else self.client
+            _ = [m async for m in client.models.list()]  # verify the server responds by listing models
+            return HealthResponse(status=HealthStatus.OK)
+        except Exception as e:
+            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
+
     async def _get_model(self, model_id: str) -> Model:
         if not self.model_store:
             raise ValueError("Model store not set")
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 17c867af1..eaa9b40da 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -11,7 +11,7 @@ import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Any
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 import pytest_asyncio
@@ -44,6 +44,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model
 from llama_stack.models.llama.datatypes import StopReason, ToolCall
+from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
 from llama_stack.providers.remote.inference.vllm.vllm import (
     VLLMInferenceAdapter,
@@ -642,3 +643,70 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
     assert chunks[-2].event.delta.tool_call.arguments == {}
+
+
+@pytest.mark.asyncio
+async def test_health_status_success(vllm_inference_adapter):
+    """
+    Test the health method of VLLM InferenceAdapter when the connection is successful.
+
+    This test verifies that the health method returns a HealthResponse with status OK
+    when the connection to the vLLM server succeeds.
+    """
+    # Set vllm_inference_adapter.client to None to ensure _create_client is called
+    vllm_inference_adapter.client = None
+    with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
+        # Create mock client and models
+        mock_client = MagicMock()
+        mock_models = MagicMock()
+
+        # Create a mock async iterator that yields a model when iterated
+        async def mock_list():
+            for model in [MagicMock()]:
+                yield model
+
+        # Set up models.list to return our mock async iterator
+        mock_models.list.return_value = mock_list()
+        mock_client.models = mock_models
+        mock_create_client.return_value = mock_client
+
+        # Call the health method
+        health_response = await vllm_inference_adapter.health()
+        # Verify the response
+        assert health_response["status"] == HealthStatus.OK
+
+        # Verify that models.list was called
+        mock_models.list.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_health_status_failure(vllm_inference_adapter):
+    """
+    Test the health method of VLLM InferenceAdapter when the connection fails.
+
+    This test verifies that the health method returns a HealthResponse with status ERROR
+    and an appropriate error message when the connection to the vLLM server fails.
+    """
+    vllm_inference_adapter.client = None
+    with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
+        # Create mock client and models
+        mock_client = MagicMock()
+        mock_models = MagicMock()
+
+        # Create a mock async iterator that raises an exception when iterated
+        async def mock_list():
+            raise Exception("Connection failed")
+            yield  # unreachable, but makes this function an async generator
+
+        # Set up models.list to return our mock async iterator
+        mock_models.list.return_value = mock_list()
+        mock_client.models = mock_models
+        mock_create_client.return_value = mock_client
+
+        # Call the health method
+        health_response = await vllm_inference_adapter.health()
+        # Verify the response
+        assert health_response["status"] == HealthStatus.ERROR
+        assert "Health check failed: Connection failed" in health_response["message"]
+
+        mock_models.list.assert_called_once()
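
Reviewer note: below is a minimal sketch of exercising the new health check directly against a
running server. The URL is a hypothetical local endpoint, and the `url` field on
VLLMInferenceAdapterConfig is assumed from the provider's existing config; treat this as an
illustration, not part of the patch.

```python
import asyncio

from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    # Hypothetical endpoint; point this at a reachable vLLM OpenAI-compatible server.
    config = VLLMInferenceAdapterConfig(url="http://localhost:8000/v1")
    adapter = VLLMInferenceAdapter(config)

    # health() creates a client on demand when self.client is None,
    # so no explicit initialization is required for this check.
    response = await adapter.health()
    print(response["status"])  # HealthStatus.OK, or HealthStatus.ERROR with a "message" key


asyncio.run(main())
```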
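For context on the router-side change, the inference.py hunk above truncates the body of
InferenceRouter.health, so here is a plausible shape for how the 1-second budget could bound each
provider's check. This is a sketch consistent with the `timeout` variable, not the actual router
code; the NOT_IMPLEMENTED branch assumes HealthStatus exposes such a member.

```python
import asyncio

from llama_stack.providers.datatypes import HealthResponse, HealthStatus


async def check_provider_health(impl, timeout: float = 1) -> HealthResponse:
    # Providers without a health method are reported as not implemented (assumed status member).
    if not hasattr(impl, "health"):
        return HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
    try:
        # Bound the remote round trip by the per-provider timeout.
        return await asyncio.wait_for(impl.health(), timeout=timeout)
    except asyncio.TimeoutError:
        return HealthResponse(status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds")
    except Exception as e:
        return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
```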