diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 0fa6b537c..a34f007a1 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -91,14 +91,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): try: model = self._get_client(self._config.model_id) model.generate("test") - return HealthResponse( - status=HealthStatus.OK - ) + return HealthResponse(status=HealthStatus.OK) except Exception as ex: - return HealthResponse( - status=HealthStatus.ERROR, - message=f"Health check failed: {str(ex)}" - ) + return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}") async def completion( self, diff --git a/tests/unit/providers/inference/test_remote_watsonx.py b/tests/unit/providers/inference/test_remote_watsonx.py index dff49ad20..53d0b4cf7 100644 --- a/tests/unit/providers/inference/test_remote_watsonx.py +++ b/tests/unit/providers/inference/test_remote_watsonx.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + from unittest.mock import MagicMock, patch import pytest @@ -7,6 +13,7 @@ from llama_stack.providers.datatypes import HealthStatus from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter + @pytest.fixture def watsonx_config(): """Create a WatsonXConfig fixture for testing.""" @@ -14,9 +21,10 @@ def watsonx_config(): url="https://test-watsonx-url.ibm.com", api_key="test-api-key", project_id="test-project-id", - model_id="test-model-id" + model_id="test-model-id", ) + @pytest_asyncio.fixture async def watsonx_inference_adapter(watsonx_config): """Create a WatsonX InferenceAdapter fixture for testing.""" @@ -24,6 +32,7 @@ async def watsonx_inference_adapter(watsonx_config): await adapter.initialize() return adapter + @pytest.mark.asyncio async def test_health_success(watsonx_inference_adapter): """ @@ -35,12 +44,13 @@ async def test_health_success(watsonx_inference_adapter): mock_model = MagicMock() mock_model.generate.return_value = "test response" - with patch.object(watsonx_inference_adapter, '_get_client', return_value=mock_model): + with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model): health_response = await watsonx_inference_adapter.health() # Verify the response assert health_response["status"] == HealthStatus.OK mock_model.generate.assert_called_once_with("test") + @pytest.mark.asyncio async def test_health_failure(watsonx_inference_adapter): """ @@ -50,7 +60,7 @@ async def test_health_failure(watsonx_inference_adapter): """ mock_model = MagicMock() mock_model.generate.side_effect = Exception("Connection failed") - with patch.object(watsonx_inference_adapter, '_get_client', return_value=mock_model): + with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model): health_response = await watsonx_inference_adapter.health() assert health_response["status"] == HealthStatus.ERROR assert "Health check failed: Connection failed" in health_response["message"]