From 67728bfccf0734959510633e0a05ea9429b3fc5f Mon Sep 17 00:00:00 2001
From: Akram Ben Aissi
Date: Mon, 15 Sep 2025 12:57:02 +0200
Subject: [PATCH] Update vLLM health check to use /health endpoint

- Replace models.list() call with HTTP GET to /health endpoint
- Remove API token validation since /health is unauthenticated
- Use urllib.parse.urljoin for cleaner URL construction
- Update tests to mock httpx.AsyncClient instead of OpenAI client
- Health check now works regardless of API token configuration

Signed-off-by: Akram Ben Aissi
---
 .../providers/remote/inference/vllm/vllm.py |  23 +++---
 .../providers/inference/test_remote_vllm.py |  81 ++++++++-----------
 2 files changed, 48 insertions(+), 56 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d1ed6365a..a83ec74a3 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -6,6 +6,7 @@
 import json
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
+from urllib.parse import urljoin
 
 import httpx
 from openai import APIConnectionError, AsyncOpenAI
@@ -316,6 +317,10 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         )
 
     async def should_refresh_models(self) -> bool:
+        # Get the default value from the field definition
+        default_api_token = self.config.__class__.model_fields["api_token"].default
+        if not self.config.api_token or self.config.api_token == default_api_token:
+            return False
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
@@ -344,21 +349,19 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         Performs a health check by verifying connectivity to the remote vLLM server.
         This method is used by the Provider API to verify that the service is running correctly.
-        Only performs the test when a static API key is provided.
+        Uses the unauthenticated /health endpoint.
 
         Returns:
             HealthResponse: A dictionary containing the health status.
         """
-        # Get the default value from the field definition
-        default_api_token = self.config.__class__.model_fields["api_token"].default
-
-        # Only perform health check if static API key is provided
-        if not self.config.api_token or self.config.api_token == default_api_token:
-            return HealthResponse(status=HealthStatus.OK)
-
         try:
-            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
-            return HealthResponse(status=HealthStatus.OK)
+            base_url = self.get_base_url()
+            health_url = urljoin(base_url, "health")
+
+            async with httpx.AsyncClient() as client:
+                response = await client.get(health_url)
+                response.raise_for_status()
+                return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 96a57c3c8..b4f5e9cf3 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -563,34 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when the connection is successful.
 
-    This test verifies that the health method returns a HealthResponse with status OK, only
-    when the connection to the vLLM server is successful.
+    This test verifies that the health method returns a HealthResponse with status OK
+    when the /health endpoint responds successfully.
     """
-    # Set a non-default API token to enable health check
-    vllm_inference_adapter.config.api_token = "real-api-key"
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock response
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
 
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        # Create mock client and models
-        mock_client = MagicMock()
-        mock_models = MagicMock()
-
-        # Create a mock async iterator that yields a model when iterated
-        async def mock_list():
-            for model in [MagicMock()]:
-                yield model
-
-        # Set up the models.list to return our mock async iterator
-        mock_models.list.return_value = mock_list()
-        mock_client.models = mock_models
-        mock_create_client.return_value = mock_client
+        # Create mock client instance; get() must be an AsyncMock so the
+        # adapter's `await client.get(...)` call can actually be awaited
+        mock_client_instance = MagicMock()
+        mock_client_instance.get = AsyncMock(return_value=mock_response)
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
 
         # Call the health method
         health_response = await vllm_inference_adapter.health()
 
+        # Verify the response
         assert health_response["status"] == HealthStatus.OK
 
-        # Verify that models.list was called
-        mock_models.list.assert_called_once()
+        # Verify that the health endpoint was called
+        mock_client_instance.get.assert_called_once()
+        call_args = mock_client_instance.get.call_args[0]
+        assert call_args[0].endswith("/health")
 
 
 async def test_health_status_failure(vllm_inference_adapter):
@@ -600,48 +595,42 @@
     """
     Test the health method of VLLM InferenceAdapter when the connection fails.
 
    This test verifies that the health method returns a HealthResponse with status ERROR
    and an appropriate error message when the connection to the vLLM server fails.
     """
-    # Set a non-default API token to enable health check
-    vllm_inference_adapter.config.api_token = "real-api-key"
-
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        # Create mock client and models
-        mock_client = MagicMock()
-        mock_models = MagicMock()
-
-        # Create a mock async iterator that raises an exception when iterated
-        async def mock_list():
-            raise Exception("Connection failed")
-            yield  # Unreachable code
-
-        # Set up the models.list to return our mock async iterator
-        mock_models.list.return_value = mock_list()
-        mock_client.models = mock_models
-        mock_create_client.return_value = mock_client
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock client instance that raises an exception
+        mock_client_instance = MagicMock()
+        mock_client_instance.get.side_effect = Exception("Connection failed")
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
 
         # Call the health method
         health_response = await vllm_inference_adapter.health()
 
+        # Verify the response
         assert health_response["status"] == HealthStatus.ERROR
         assert "Health check failed: Connection failed" in health_response["message"]
 
-        mock_models.list.assert_called_once()
-
 
 async def test_health_status_no_static_api_key(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when no static API key is provided.
 
     This test verifies that the health method returns a HealthResponse with status OK
-    without performing any connectivity test when no static API key is provided.
+    when the /health endpoint responds successfully, regardless of API token configuration.
""" - # Ensure api_token is the default value (no static API key) - vllm_inference_adapter.config.api_token = "fake" + with patch("httpx.AsyncClient") as mock_client_class: + # Create mock response + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None - # Call the health method - health_response = await vllm_inference_adapter.health() + # Create mock client instance + mock_client_instance = MagicMock() + mock_client_instance.get.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client_instance - # Verify the response - assert health_response["status"] == HealthStatus.OK + # Call the health method + health_response = await vllm_inference_adapter.health() + + # Verify the response + assert health_response["status"] == HealthStatus.OK async def test_openai_chat_completion_is_async(vllm_inference_adapter):