From 8de3feb53b600af028ebf03592f57bb098d88aa4 Mon Sep 17 00:00:00 2001 From: Sumit Jaiswal Date: Sun, 1 Jun 2025 18:10:32 +0530 Subject: [PATCH] PR to implement watsonx health check --- .../remote/inference/watsonx/config.py | 4 ++ .../remote/inference/watsonx/watsonx.py | 19 ++++++ .../inference/test_remote_watsonx.py | 66 +++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 tests/unit/providers/inference/test_remote_watsonx.py diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 5eda9c5c0..8486f22e5 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -32,6 +32,10 @@ class WatsonXConfig(BaseModel): default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"), description="The Project ID key, only needed of using the hosted service", ) + model_id: str | None = Field( + default_factory=lambda: os.getenv("WATSONX_MODEL_ID", "ibm/granite-3-8b-instruct"), + description="The Model ID key, only needed of using the hosted service", + ) timeout: int = Field( default=60, description="Timeout for the HTTP requests", diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 59f5f5562..a34f007a1 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -40,6 +40,10 @@ from llama_stack.apis.inference.inference import ( TopKSamplingStrategy, TopPSamplingStrategy, ) +from llama_stack.providers.datatypes import ( + HealthResponse, + HealthStatus, +) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionChoice, @@ -76,6 +80,21 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): async def shutdown(self) -> None: pass + async def health(self) -> HealthResponse: + """ + Performs a health check by verifying connectivity to the Watsonx server. + This method is used by the Provider API to verify + that the service is running correctly. + Returns: + HealthResponse: A dictionary containing the health status. + """ + try: + model = self._get_client(self._config.model_id) + model.generate("test") + return HealthResponse(status=HealthStatus.OK) + except Exception as ex: + return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}") + async def completion( self, model_id: str, diff --git a/tests/unit/providers/inference/test_remote_watsonx.py b/tests/unit/providers/inference/test_remote_watsonx.py new file mode 100644 index 000000000..53d0b4cf7 --- /dev/null +++ b/tests/unit/providers/inference/test_remote_watsonx.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import MagicMock, patch + +import pytest +import pytest_asyncio + +from llama_stack.providers.datatypes import HealthStatus +from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig +from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter + + +@pytest.fixture +def watsonx_config(): + """Create a WatsonXConfig fixture for testing.""" + return WatsonXConfig( + url="https://test-watsonx-url.ibm.com", + api_key="test-api-key", + project_id="test-project-id", + model_id="test-model-id", + ) + + +@pytest_asyncio.fixture +async def watsonx_inference_adapter(watsonx_config): + """Create a WatsonX InferenceAdapter fixture for testing.""" + adapter = WatsonXInferenceAdapter(watsonx_config) + await adapter.initialize() + return adapter + + +@pytest.mark.asyncio +async def test_health_success(watsonx_inference_adapter): + """ + Test the health status of WatsonX InferenceAdapter when the connection is successful. + This test verifies that the health method returns a HealthResponse with status OK, only + when the connection to the WatsonX server is successful. + """ + # Mock the _get_client method to return a mock model + mock_model = MagicMock() + mock_model.generate.return_value = "test response" + + with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model): + health_response = await watsonx_inference_adapter.health() + # Verify the response + assert health_response["status"] == HealthStatus.OK + mock_model.generate.assert_called_once_with("test") + + +@pytest.mark.asyncio +async def test_health_failure(watsonx_inference_adapter): + """ + Test the health method of WatsonX InferenceAdapter when the connection fails. + This test verifies that the health method returns a HealthResponse with status ERROR, + with the exception error message. + """ + mock_model = MagicMock() + mock_model.generate.side_effect = Exception("Connection failed") + with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model): + health_response = await watsonx_inference_adapter.health() + assert health_response["status"] == HealthStatus.ERROR + assert "Health check failed: Connection failed" in health_response["message"]