Mirror of https://github.com/meta-llama/llama-stack.git
Update vLLM health check to use /health endpoint
- Replace models.list() call with HTTP GET to /health endpoint
- Remove API token validation since /health is unauthenticated
- Use urllib.parse.urljoin for cleaner URL construction
- Update tests to mock httpx.AsyncClient instead of OpenAI client
- Health check now works regardless of API token configuration

Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
Parent: 5e74bc7fcf
Commit: 67728bfccf
2 changed files with 48 additions and 56 deletions
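The cleaner URL construction mentioned above relies on the standard library's urllib.parse.urljoin. Below is a minimal sketch of its semantics, using made-up base URLs rather than anything read from the adapter's configuration; note that a trailing slash on the base changes the result.

# Sketch only: illustrates urljoin semantics with made-up base URLs.
from urllib.parse import urljoin

# A base URL ending in "/" keeps its last path segment when joined.
print(urljoin("http://localhost:8000/v1/", "health"))   # http://localhost:8000/v1/health

# Without the trailing slash, the last segment is replaced.
print(urljoin("http://localhost:8000/v1", "health"))    # http://localhost:8000/health

# An absolute path in the second argument always resolves from the host root.
print(urljoin("http://localhost:8000/v1/", "/health"))  # http://localhost:8000/health

This is why joining the relative segment "health" versus the absolute path "/health" can land on different URLs depending on how the base URL is written.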
Changes to the vLLM inference adapter (first changed file):

@@ -6,6 +6,7 @@
 import json
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
+from urllib.parse import urljoin

 import httpx
 from openai import APIConnectionError, AsyncOpenAI
@@ -316,6 +317,10 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         )

     async def should_refresh_models(self) -> bool:
+        # Get the default value from the field definition
+        default_api_token = self.config.__class__.model_fields["api_token"].default
+        if not self.config.api_token or self.config.api_token == default_api_token:
+            return False
         return self.config.refresh_models

     async def list_models(self) -> list[Model] | None:
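The guard added to should_refresh_models compares the configured token against the field's declared default. A minimal sketch of that Pydantic v2 lookup, with a hypothetical ExampleConfig standing in for the adapter's real config class:

# Sketch assuming Pydantic v2; ExampleConfig is hypothetical, not the adapter's config class.
from pydantic import BaseModel


class ExampleConfig(BaseModel):
    api_token: str = "fake"  # declared default


cfg = ExampleConfig(api_token="real-api-key")

# model_fields maps field names to FieldInfo objects; .default is the declared default.
default_api_token = type(cfg).model_fields["api_token"].default

print(default_api_token)                   # fake
print(cfg.api_token == default_api_token)  # False -> a non-default token was supplied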
@@ -344,21 +349,19 @@
         Performs a health check by verifying connectivity to the remote vLLM server.
         This method is used by the Provider API to verify
         that the service is running correctly.
-        Only performs the test when a static API key is provided.
+        Uses the unauthenticated /health endpoint.

         Returns:
             HealthResponse: A dictionary containing the health status.
         """
-        # Get the default value from the field definition
-        default_api_token = self.config.__class__.model_fields["api_token"].default
-
-        # Only perform health check if static API key is provided
-        if not self.config.api_token or self.config.api_token == default_api_token:
-            return HealthResponse(status=HealthStatus.OK)
-
         try:
-            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
-            return HealthResponse(status=HealthStatus.OK)
+            base_url = self.get_base_url()
+            health_url = urljoin(base_url, "health")
+
+            async with httpx.AsyncClient() as client:
+                response = await client.get(health_url)
+                response.raise_for_status()
+                return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
Changes to the vLLM adapter unit tests (second changed file):

@@ -563,34 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when the connection is successful.

-    This test verifies that the health method returns a HealthResponse with status OK, only
-    when the connection to the vLLM server is successful.
+    This test verifies that the health method returns a HealthResponse with status OK
+    when the /health endpoint responds successfully.
     """
-    # Set a non-default API token to enable health check
-    vllm_inference_adapter.config.api_token = "real-api-key"
-
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        # Create mock client and models
-        mock_client = MagicMock()
-        mock_models = MagicMock()
-
-        # Create a mock async iterator that yields a model when iterated
-        async def mock_list():
-            for model in [MagicMock()]:
-                yield model
-
-        # Set up the models.list to return our mock async iterator
-        mock_models.list.return_value = mock_list()
-        mock_client.models = mock_models
-        mock_create_client.return_value = mock_client
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock response
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
+
+        # Create mock client instance
+        mock_client_instance = MagicMock()
+        mock_client_instance.get.return_value = mock_response
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance

         # Call the health method
         health_response = await vllm_inference_adapter.health()

         # Verify the response
         assert health_response["status"] == HealthStatus.OK

-        # Verify that models.list was called
-        mock_models.list.assert_called_once()
+        # Verify that the health endpoint was called
+        mock_client_instance.get.assert_called_once()
+        call_args = mock_client_instance.get.call_args[0]
+        assert call_args[0].endswith("/health")


 async def test_health_status_failure(vllm_inference_adapter):
@@ -600,48 +595,42 @@ async def test_health_status_failure(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status ERROR
     and an appropriate error message when the connection to the vLLM server fails.
     """
-    # Set a non-default API token to enable health check
-    vllm_inference_adapter.config.api_token = "real-api-key"
-
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        # Create mock client and models
-        mock_client = MagicMock()
-        mock_models = MagicMock()
-
-        # Create a mock async iterator that raises an exception when iterated
-        async def mock_list():
-            raise Exception("Connection failed")
-            yield  # Unreachable code
-
-        # Set up the models.list to return our mock async iterator
-        mock_models.list.return_value = mock_list()
-        mock_client.models = mock_models
-        mock_create_client.return_value = mock_client
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock client instance that raises an exception
+        mock_client_instance = MagicMock()
+        mock_client_instance.get.side_effect = Exception("Connection failed")
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance

         # Call the health method
         health_response = await vllm_inference_adapter.health()

         # Verify the response
         assert health_response["status"] == HealthStatus.ERROR
         assert "Health check failed: Connection failed" in health_response["message"]

-        mock_models.list.assert_called_once()
-

 async def test_health_status_no_static_api_key(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when no static API key is provided.

     This test verifies that the health method returns a HealthResponse with status OK
-    without performing any connectivity test when no static API key is provided.
+    when the /health endpoint responds successfully, regardless of API token configuration.
     """
-    # Ensure api_token is the default value (no static API key)
-    vllm_inference_adapter.config.api_token = "fake"
-
-    # Call the health method
-    health_response = await vllm_inference_adapter.health()
-
-    # Verify the response
-    assert health_response["status"] == HealthStatus.OK
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock response
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
+
+        # Create mock client instance
+        mock_client_instance = MagicMock()
+        mock_client_instance.get.return_value = mock_response
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
+
+        # Call the health method
+        health_response = await vllm_inference_adapter.health()
+
+        # Verify the response
+        assert health_response["status"] == HealthStatus.OK


 async def test_openai_chat_completion_is_async(vllm_inference_adapter):
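For reference, here is a self-contained sketch of the httpx.AsyncClient mocking pattern the updated tests build on: patch the class so the async context manager yields a mock client whose get() call is awaitable. The probe_health helper and the AsyncMock usage are illustrative assumptions, not the repository's test code; the sketch only assumes httpx is installed and Python 3.8+ (where MagicMock supports __aenter__/__aexit__).

# Standalone sketch of the httpx.AsyncClient mocking pattern (not the repo's tests).
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import httpx


async def probe_health(url: str) -> int:
    """Hypothetical helper: GET a health URL and return the status code."""
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.status_code


async def main() -> None:
    with patch("httpx.AsyncClient") as mock_client_class:
        mock_response = MagicMock(status_code=200)
        mock_response.raise_for_status.return_value = None

        # AsyncMock makes client.get(...) awaitable; __aenter__ returns the client.
        mock_client = AsyncMock()
        mock_client.get.return_value = mock_response
        mock_client_class.return_value.__aenter__.return_value = mock_client

        assert await probe_health("http://localhost:8000/health") == 200
        mock_client.get.assert_called_once_with("http://localhost:8000/health")


asyncio.run(main())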