Update vLLM health check to use /health endpoint

- Replace models.list() call with HTTP GET to /health endpoint
- Remove API token validation since /health is unauthenticated
- Use urllib.parse.urljoin for cleaner URL construction (see the urljoin sketch below)
- Update tests to mock httpx.AsyncClient instead of OpenAI client
- Health check now works regardless of API token configuration

Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
Akram Ben Aissi 2025-09-15 12:57:02 +02:00
parent 5e74bc7fcf
commit 67728bfccf
2 changed files with 48 additions and 56 deletions
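
A note on the urljoin change called out above: urllib.parse.urljoin resolves its second argument against the last path segment of the base URL, so whether the base ends with a trailing slash changes the resulting path. A minimal sketch, with illustrative URLs rather than the adapter's real configuration:

from urllib.parse import urljoin

# A trailing slash keeps the /v1 prefix; without it the last segment is replaced.
print(urljoin("http://localhost:8000/v1/", "health"))  # http://localhost:8000/v1/health
print(urljoin("http://localhost:8000/v1", "health"))   # http://localhost:8000/health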

@@ -6,6 +6,7 @@
import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError, AsyncOpenAI
@@ -316,6 +317,10 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
)
async def should_refresh_models(self) -> bool:
# Get the default value from the field definition
default_api_token = self.config.__class__.model_fields["api_token"].default
if not self.config.api_token or self.config.api_token == default_api_token:
return False
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
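
For reference, the default-token comparison in should_refresh_models above reads the field's declared default through Pydantic v2's model_fields mapping. A minimal sketch with a made-up stand-in for the config class (the "fake" default is illustrative):

from pydantic import BaseModel

class ExampleAdapterConfig(BaseModel):
    # Stand-in for the vLLM adapter config; "fake" plays the role of the placeholder default token.
    api_token: str = "fake"
    refresh_models: bool = False

default_api_token = ExampleAdapterConfig.model_fields["api_token"].default
print(default_api_token)  # fake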
@@ -344,21 +349,19 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
Performs a health check by verifying connectivity to the remote vLLM server.
This method is used by the Provider API to verify
that the service is running correctly.
Only performs the test when a static API key is provided.
Uses the unauthenticated /health endpoint.
Returns:
HealthResponse: A dictionary containing the health status.
"""
# Get the default value from the field definition
default_api_token = self.config.__class__.model_fields["api_token"].default
# Only perform health check if static API key is provided
if not self.config.api_token or self.config.api_token == default_api_token:
return HealthResponse(status=HealthStatus.OK)
try:
_ = [m async for m in self.client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
base_url = self.get_base_url()
health_url = urljoin(base_url, "health")
async with httpx.AsyncClient() as client:
response = await client.get(health_url)
response.raise_for_status()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
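
To see the new check in isolation, here is a hedged standalone sketch of the same idea outside the adapter: an unauthenticated GET against the server's /health route, treating any 2xx response as healthy. The helper name and base URL are made up for illustration:

import asyncio

import httpx

async def check_vllm_health(base_url: str) -> bool:
    # Hypothetical helper mirroring the adapter's health check: GET /health, no API token required.
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{base_url.rstrip('/')}/health")
        return response.is_success

# Example usage: asyncio.run(check_vllm_health("http://localhost:8000"))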

@@ -563,34 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
This test verifies that the health method returns a HealthResponse with status OK, only
when the connection to the vLLM server is successful.
This test verifies that the health method returns a HealthResponse with status OK
when the /health endpoint responds successfully.
"""
# Set a non-default API token to enable health check
vllm_inference_adapter.config.api_token = "real-api-key"
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock response
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
# Create a mock async iterator that yields a model when iterated
async def mock_list():
for model in [MagicMock()]:
yield model
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models
mock_create_client.return_value = mock_client
# Create mock client instance; AsyncMock (from unittest.mock) keeps the awaited
# client.get(...) call returning mock_response, since a plain MagicMock is not awaitable
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
# Verify that models.list was called
mock_models.list.assert_called_once()
# Verify that the health endpoint was called
mock_client_instance.get.assert_called_once()
call_args = mock_client_instance.get.call_args[0]
assert call_args[0].endswith("/health")
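
The mocking pattern above hinges on one detail: the adapter awaits client.get(...), so the patched client's get must return an awaitable. A self-contained sketch of the pattern with illustrative names (not the actual test fixtures):

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import httpx

async def fetch_health_status(url: str) -> int:
    # Minimal stand-in for the adapter's health check.
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        return response.status_code

async def main() -> None:
    with patch("httpx.AsyncClient") as mock_client_class:
        mock_response = MagicMock(status_code=200)
        # AsyncMock makes client.get(...) awaitable; awaiting a plain MagicMock
        # raises "object MagicMock can't be used in 'await' expression".
        mock_client = AsyncMock()
        mock_client.get.return_value = mock_response
        mock_client_class.return_value.__aenter__.return_value = mock_client
        assert await fetch_health_status("http://example.invalid/health") == 200

asyncio.run(main())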
async def test_health_status_failure(vllm_inference_adapter):
@@ -600,48 +595,42 @@ async def test_health_status_failure(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status ERROR
and an appropriate error message when the connection to the vLLM server fails.
"""
# Set a non-default API token to enable health check
vllm_inference_adapter.config.api_token = "real-api-key"
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
# Create a mock async iterator that raises an exception when iterated
async def mock_list():
raise Exception("Connection failed")
yield # Unreachable code
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models
mock_create_client.return_value = mock_client
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock client instance that raises an exception
mock_client_instance = MagicMock()
mock_client_instance.get.side_effect = Exception("Connection failed")
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.ERROR
assert "Health check failed: Connection failed" in health_response["message"]
mock_models.list.assert_called_once()
async def test_health_status_no_static_api_key(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when no static API key is provided.
This test verifies that the health method returns a HealthResponse with status OK
without performing any connectivity test when no static API key is provided.
when the /health endpoint responds successfully, regardless of API token configuration.
"""
# Ensure api_token is the default value (no static API key)
vllm_inference_adapter.config.api_token = "fake"
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock response
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
# Call the health method
health_response = await vllm_inference_adapter.health()
# Create mock client instance; AsyncMock keeps the awaited client.get(...) call
# returning mock_response, since a plain MagicMock is not awaitable
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Verify the response
assert health_response["status"] == HealthStatus.OK
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
async def test_openai_chat_completion_is_async(vllm_inference_adapter):