Add health status check for remote vLLM

Sumit Jaiswal 2025-05-29 02:10:13 +05:30
parent b21050935e
commit 6d1cf140ba
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
2 changed files with 71 additions and 2 deletions

View file

@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

import httpx
import requests
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
    ChatCompletionChunk as OpenAIChatCompletionChunk,
@@ -56,7 +57,11 @@ from llama_stack.apis.inference.inference import (
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
@@ -298,6 +303,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    async def unregister_model(self, model_id: str) -> None:
        pass

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the remote vLLM server.
        This method is used by initialize() and the Provider API to verify
        that the service is running correctly.

        Returns:
            HealthResponse: A dictionary containing the health status.
        """
        try:
            headers = {}
            client = self._create_client() if self.client is None else self.client
            if client.api_key:
                headers["Authorization"] = f"Bearer {client.api_key}"
            models_url = f"{client.base_url}/v1/models"
            # Treat non-2xx responses as failures so they surface as ERROR below.
            response = requests.get(models_url, headers=headers, timeout=10)
            response.raise_for_status()
            return HealthResponse(
                status=HealthStatus.OK
            )
        except Exception as ex:
            return HealthResponse(
                status=HealthStatus.ERROR,
                message=f"Health check failed: {str(ex)}"
            )

    async def _get_model(self, model_id: str) -> Model:
        if not self.model_store:
            raise ValueError("Model store not set")
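
For context, here is a minimal sketch (not part of this commit) of exercising the new health() method directly against a running vLLM server. The adapter constructor signature and the config fields used below (url, api_token) are assumptions inferred from the imports in this diff.

import asyncio

from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    # Point the adapter at a locally running vLLM OpenAI-compatible endpoint (assumed URL).
    config = VLLMInferenceAdapterConfig(url="http://localhost:8000/v1", api_token="fake")
    adapter = VLLMInferenceAdapter(config)

    # health() probes {base_url}/v1/models and reports OK or ERROR.
    health = await adapter.health()
    if health["status"] == HealthStatus.OK:
        print("remote vLLM server is reachable")
    else:
        print(f"health check failed: {health.get('message')}")


asyncio.run(main())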

View file

@@ -11,7 +11,7 @@ import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio
@@ -44,6 +44,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
    VLLMInferenceAdapter,
@@ -639,3 +640,42 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
    assert chunks[-2].event.delta.type == "tool_call"
    assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
    assert chunks[-2].event.delta.tool_call.arguments == {}


@pytest.mark.asyncio
async def test_health_status_success(vllm_inference_adapter):
    """
    Test the health method of VLLMInferenceAdapter when the connection is successful.

    This test verifies that the health method returns a HealthResponse with status OK
    when the connection to the vLLM server succeeds.
    """
    # Mock requests.get to simulate a successful response from the vLLM server
    with patch("requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_get.return_value = mock_response

        # Call the health method
        health_response = await vllm_inference_adapter.health()

        # Verify the response
        assert health_response["status"] == HealthStatus.OK
        mock_get.assert_called_once()


@pytest.mark.asyncio
async def test_health_status_failure(vllm_inference_adapter):
    """
    Test the health method of VLLMInferenceAdapter when the connection fails.

    This test verifies that the health method returns a HealthResponse with status ERROR
    and an appropriate error message when the connection to the vLLM server fails.
    """
    # Mock requests.get to raise an exception, simulating an unreachable server
    with patch("requests.get") as mock_get:
        mock_get.side_effect = Exception("Connection failed")

        health_response = await vllm_inference_adapter.health()

        # Verify the response
        assert health_response["status"] == HealthStatus.ERROR
        assert "Health check failed: Connection failed" in health_response["message"]
        # Verify that requests.get was called
        mock_get.assert_called_once()
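
A natural follow-up test (sketched here, not part of the commit) would cover the case where the server is reachable but answers with an error status. This assumes health() calls raise_for_status() on the probe response so that non-2xx responses are reported as ERROR.

@pytest.mark.asyncio
async def test_health_status_http_error(vllm_inference_adapter):
    """
    Sketch: the vLLM server responds, but with an HTTP error status.
    Assumes health() raises on non-2xx responses via raise_for_status().
    """
    with patch("requests.get") as mock_get:
        mock_response = MagicMock()
        # Simulate an HTTP 500 surfacing from raise_for_status()
        mock_response.raise_for_status.side_effect = Exception("500 Server Error")
        mock_get.return_value = mock_response

        health_response = await vllm_inference_adapter.health()

        assert health_response["status"] == HealthStatus.ERROR
        assert "Health check failed" in health_response["message"]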