diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 9f38d9abf..ea972a1b7 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
 import httpx
+import requests
 from openai import AsyncOpenAI
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
@@ -56,7 +57,11 @@ from llama_stack.apis.inference.inference import (
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import (
+    HealthResponse,
+    HealthStatus,
+    ModelsProtocolPrivate,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
@@ -298,6 +303,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def health(self) -> HealthResponse:
+        """
+        Performs a health check by verifying connectivity to the remote VLLM server.
+        This method is used by initialize() and the Provider API to verify
+        that the service is running correctly.
+        Returns:
+            HealthResponse: A dictionary containing the health status.
+        """
+        try:
+            headers = {}
+            client = self._create_client() if self.client is None else self.client
+            if client.api_key:
+                headers["Authorization"] = f"Bearer {client.api_key}"
+            models_url = f"{client.base_url}/v1/models"
+            requests.get(models_url, headers=headers, timeout=10)
+            return HealthResponse(
+                status=HealthStatus.OK
+            )
+        except Exception as ex:
+            return HealthResponse(
+                status=HealthStatus.ERROR,
+                message=f"Health check failed: {str(ex)}"
+            )
+
     async def _get_model(self, model_id: str) -> Model:
         if not self.model_store:
             raise ValueError("Model store not set")
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index f9eaee7d6..ed6683c72 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -11,7 +11,7 @@ import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Any
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 import pytest_asyncio
@@ -44,6 +44,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model
 from llama_stack.models.llama.datatypes import StopReason, ToolCall
+from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
 from llama_stack.providers.remote.inference.vllm.vllm import (
     VLLMInferenceAdapter,
@@ -639,3 +640,42 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
     assert chunks[-2].event.delta.tool_call.arguments == {}
+
+
+@pytest.mark.asyncio
+async def test_health_status_success(vllm_inference_adapter):
+    """
+    Test the health method of VLLM InferenceAdapter when the connection is successful.
+
+    This test verifies that the health method returns a HealthResponse with status OK, only
+    when the connection to the vLLM server is successful.
+    """
+    # Mock the requests.get method to return a successful response
+    with patch('requests.get') as mock_get:
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_get.return_value = mock_response
+        # Call the health method
+        health_response = await vllm_inference_adapter.health()
+        # Verify the response
+        assert health_response["status"] == HealthStatus.OK
+        mock_get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_health_status_failure(vllm_inference_adapter):
+    """
+    Test the health method of VLLM InferenceAdapter when the connection fails.
+
+    This test verifies that the health method returns a HealthResponse with status ERROR
+    and an appropriate error message when the connection to the vLLM server fails.
+    """
+    # Mock the requests.get method to raise an exception
+    with patch('requests.get') as mock_get:
+        mock_get.side_effect = Exception("Connection failed")
+        health_response = await vllm_inference_adapter.health()
+        # Verify the response
+        assert health_response["status"] == HealthStatus.ERROR
+        assert "Health check failed: Connection failed" in health_response["message"]
+        # Verify that requests.get was called
+        mock_get.assert_called_once()
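
For reference, a minimal sketch of how a caller might consume the new health() method. It is illustrative only: the variable adapter is assumed to be an already-initialized VLLMInferenceAdapter, and the helper name log_provider_health is hypothetical, not part of the patch.

# Hypothetical usage sketch; `adapter` is assumed to be an initialized
# VLLMInferenceAdapter built from a VLLMInferenceAdapterConfig.
from llama_stack.providers.datatypes import HealthStatus

async def log_provider_health(adapter) -> None:
    health = await adapter.health()
    if health["status"] == HealthStatus.OK:
        print("vLLM provider is reachable")
    else:
        # On failure the patch populates a "message" key with the error details.
        print(f"vLLM provider unhealthy: {health['message']}")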