Add health status check for remote vLLM

Sumit Jaiswal 2025-05-29 02:10:13 +05:30
parent b21050935e
commit 6d1cf140ba
2 changed files with 71 additions and 2 deletions

llama_stack/providers/remote/inference/vllm/vllm.py View file

@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
import requests
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
    ChatCompletionChunk as OpenAIChatCompletionChunk,
@@ -56,7 +57,11 @@ from llama_stack.apis.inference.inference import (
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
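
The code and tests in this commit construct HealthResponse with keyword arguments and read it back with dictionary-style subscripting. A minimal sketch of that contract as it is exercised here (illustrative only; the real definitions live in llama_stack.providers.datatypes):

from llama_stack.providers.datatypes import HealthResponse, HealthStatus

# As used in this diff, HealthResponse behaves like a dict:
# built from kwargs, read via subscript.
ok = HealthResponse(status=HealthStatus.OK)
err = HealthResponse(status=HealthStatus.ERROR, message="connection refused")
assert ok["status"] == HealthStatus.OK
assert "connection refused" in err["message"]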
@@ -298,6 +303,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    async def unregister_model(self, model_id: str) -> None:
        pass

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the remote vLLM server.

        This method is used by initialize() and the Provider API to verify that
        the service is running correctly.

        Returns:
            HealthResponse: A dictionary containing the health status.
        """
        try:
            headers = {}
            client = self._create_client() if self.client is None else self.client
            if client.api_key:
                headers["Authorization"] = f"Bearer {client.api_key}"
            models_url = f"{client.base_url}/v1/models"
            # Raise on non-2xx responses so an unhealthy server is reported
            # as an error instead of being silently treated as OK.
            response = requests.get(models_url, headers=headers, timeout=10)
            response.raise_for_status()
            return HealthResponse(status=HealthStatus.OK)
        except Exception as ex:
            return HealthResponse(
                status=HealthStatus.ERROR,
                message=f"Health check failed: {str(ex)}",
            )

    async def _get_model(self, model_id: str) -> Model:
        if not self.model_store:
            raise ValueError("Model store not set")
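
The docstring above notes that health() is consumed by initialize() and the Provider API. A hypothetical caller (not part of this commit) might use it to fail fast when the server is unreachable:

# Hypothetical caller sketch -- not taken from this commit.
async def ensure_healthy(adapter: VLLMInferenceAdapter) -> None:
    health = await adapter.health()
    if health["status"] == HealthStatus.ERROR:
        # Surfaces the adapter's message, e.g. "Health check failed: ..."
        raise RuntimeError(health.get("message", "vLLM health check failed"))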

tests/unit/providers/inference/test_remote_vllm.py View file

@@ -11,7 +11,7 @@ import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
@@ -44,6 +44,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
@@ -639,3 +640,42 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
    assert chunks[-2].event.delta.type == "tool_call"
    assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
    assert chunks[-2].event.delta.tool_call.arguments == {}


@pytest.mark.asyncio
async def test_health_status_success(vllm_inference_adapter):
    """
    Test the health method of VLLMInferenceAdapter when the connection is successful.

    This test verifies that the health method returns a HealthResponse with status OK
    when the connection to the vLLM server succeeds.
    """
    # Mock requests.get to simulate a successful response
    with patch('requests.get') as mock_get:
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_get.return_value = mock_response

        # Call the health method
        health_response = await vllm_inference_adapter.health()

        # Verify the response
        assert health_response["status"] == HealthStatus.OK
        mock_get.assert_called_once()


@pytest.mark.asyncio
async def test_health_status_failure(vllm_inference_adapter):
    """
    Test the health method of VLLMInferenceAdapter when the connection fails.

    This test verifies that the health method returns a HealthResponse with status ERROR
    and an appropriate error message when the connection to the vLLM server fails.
    """
    # Mock requests.get to raise an exception
    with patch('requests.get') as mock_get:
        mock_get.side_effect = Exception("Connection failed")

        health_response = await vllm_inference_adapter.health()

        # Verify the response
        assert health_response["status"] == HealthStatus.ERROR
        assert "Health check failed: Connection failed" in health_response["message"]
        # Verify that requests.get was called
        mock_get.assert_called_once()
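
Both tests rely on a vllm_inference_adapter fixture defined earlier in this test module and not shown in this diff. A minimal sketch of what such a fixture could look like; the URL and construction details below are assumptions, not taken from the commit:

@pytest_asyncio.fixture
async def vllm_inference_adapter():
    # Assumed shape -- the real fixture in the module may differ.
    config = VLLMInferenceAdapterConfig(url="http://localhost:12345")
    adapter = VLLMInferenceAdapter(config)
    adapter.model_store = AsyncMock()  # avoid requiring a real model store
    return adapter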