diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 5eda9c5c0..8486f22e5 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -32,6 +32,10 @@ class WatsonXConfig(BaseModel): default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"), description="The Project ID key, only needed of using the hosted service", ) + model_id: str | None = Field( + default_factory=lambda: os.getenv("WATSONX_MODEL_ID", "ibm/granite-3-8b-instruct"), + description="The Model ID key, only needed of using the hosted service", + ) timeout: int = Field( default=60, description="Timeout for the HTTP requests", diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index c1299e11f..e71b30da8 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio from collections.abc import AsyncGenerator, AsyncIterator from typing import Any @@ -39,6 +40,10 @@ from llama_stack.apis.inference.inference import ( TopKSamplingStrategy, TopPSamplingStrategy, ) +from llama_stack.providers.datatypes import ( + HealthResponse, + HealthStatus, +) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionChoice, @@ -75,6 +80,26 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): async def shutdown(self) -> None: pass + async def health(self) -> HealthResponse: + """ + Performs a health check by verifying connectivity to the Watsonx server. + This method is used by initialize() and the Provider API to verify + that the service is running correctly. + Returns: + HealthResponse: A dictionary containing the health status. + """ + try: + model = self._get_client(self._config.model_id) + model.generate("test") + return HealthResponse( + status=HealthStatus.OK + ) + except Exception as ex: + return HealthResponse( + status=HealthStatus.ERROR, + message=f"Health check failed: {str(ex)}" + ) + async def completion( self, model_id: str, diff --git a/tests/unit/providers/inference/test_remote_watsonx.py b/tests/unit/providers/inference/test_remote_watsonx.py new file mode 100644 index 000000000..dff49ad20 --- /dev/null +++ b/tests/unit/providers/inference/test_remote_watsonx.py @@ -0,0 +1,56 @@ +from unittest.mock import MagicMock, patch + +import pytest +import pytest_asyncio + +from llama_stack.providers.datatypes import HealthStatus +from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig +from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter + +@pytest.fixture +def watsonx_config(): + """Create a WatsonXConfig fixture for testing.""" + return WatsonXConfig( + url="https://test-watsonx-url.ibm.com", + api_key="test-api-key", + project_id="test-project-id", + model_id="test-model-id" + ) + +@pytest_asyncio.fixture +async def watsonx_inference_adapter(watsonx_config): + """Create a WatsonX InferenceAdapter fixture for testing.""" + adapter = WatsonXInferenceAdapter(watsonx_config) + await adapter.initialize() + return adapter + +@pytest.mark.asyncio +async def test_health_success(watsonx_inference_adapter): + """ + Test the health status of WatsonX InferenceAdapter when the connection is successful. + This test verifies that the health method returns a HealthResponse with status OK, only + when the connection to the WatsonX server is successful. + """ + # Mock the _get_client method to return a mock model + mock_model = MagicMock() + mock_model.generate.return_value = "test response" + + with patch.object(watsonx_inference_adapter, '_get_client', return_value=mock_model): + health_response = await watsonx_inference_adapter.health() + # Verify the response + assert health_response["status"] == HealthStatus.OK + mock_model.generate.assert_called_once_with("test") + +@pytest.mark.asyncio +async def test_health_failure(watsonx_inference_adapter): + """ + Test the health method of WatsonX InferenceAdapter when the connection fails. + This test verifies that the health method returns a HealthResponse with status ERROR, + with the exception error message. + """ + mock_model = MagicMock() + mock_model.generate.side_effect = Exception("Connection failed") + with patch.object(watsonx_inference_adapter, '_get_client', return_value=mock_model): + health_response = await watsonx_inference_adapter.health() + assert health_response["status"] == HealthStatus.ERROR + assert "Health check failed: Connection failed" in health_response["message"]