updates to fix pre-commit checks

This commit is contained in:
Sumit Jaiswal 2025-06-01 17:51:02 +05:30
parent 6ec2ed4196
commit 319300fe24
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
2 changed files with 15 additions and 10 deletions

View file

@ -91,14 +91,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
try: try:
model = self._get_client(self._config.model_id) model = self._get_client(self._config.model_id)
model.generate("test") model.generate("test")
return HealthResponse( return HealthResponse(status=HealthStatus.OK)
status=HealthStatus.OK
)
except Exception as ex: except Exception as ex:
return HealthResponse( return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}")
status=HealthStatus.ERROR,
message=f"Health check failed: {str(ex)}"
)
async def completion( async def completion(
self, self,

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -7,6 +13,7 @@ from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
@pytest.fixture @pytest.fixture
def watsonx_config(): def watsonx_config():
"""Create a WatsonXConfig fixture for testing.""" """Create a WatsonXConfig fixture for testing."""
@ -14,9 +21,10 @@ def watsonx_config():
url="https://test-watsonx-url.ibm.com", url="https://test-watsonx-url.ibm.com",
api_key="test-api-key", api_key="test-api-key",
project_id="test-project-id", project_id="test-project-id",
model_id="test-model-id" model_id="test-model-id",
) )
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def watsonx_inference_adapter(watsonx_config): async def watsonx_inference_adapter(watsonx_config):
"""Create a WatsonX InferenceAdapter fixture for testing.""" """Create a WatsonX InferenceAdapter fixture for testing."""
@ -24,6 +32,7 @@ async def watsonx_inference_adapter(watsonx_config):
await adapter.initialize() await adapter.initialize()
return adapter return adapter
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_success(watsonx_inference_adapter): async def test_health_success(watsonx_inference_adapter):
""" """
@ -35,12 +44,13 @@ async def test_health_success(watsonx_inference_adapter):
mock_model = MagicMock() mock_model = MagicMock()
mock_model.generate.return_value = "test response" mock_model.generate.return_value = "test response"
with patch.object(watsonx_inference_adapter, '_get_client', return_value=mock_model): with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model):
health_response = await watsonx_inference_adapter.health() health_response = await watsonx_inference_adapter.health()
# Verify the response # Verify the response
assert health_response["status"] == HealthStatus.OK assert health_response["status"] == HealthStatus.OK
mock_model.generate.assert_called_once_with("test") mock_model.generate.assert_called_once_with("test")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_failure(watsonx_inference_adapter): async def test_health_failure(watsonx_inference_adapter):
""" """
@ -50,7 +60,7 @@ async def test_health_failure(watsonx_inference_adapter):
""" """
mock_model = MagicMock() mock_model = MagicMock()
mock_model.generate.side_effect = Exception("Connection failed") mock_model.generate.side_effect = Exception("Connection failed")
with patch.object(watsonx_inference_adapter, '_get_client', return_value=mock_model): with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model):
health_response = await watsonx_inference_adapter.health() health_response = await watsonx_inference_adapter.health()
assert health_response["status"] == HealthStatus.ERROR assert health_response["status"] == HealthStatus.ERROR
assert "Health check failed: Connection failed" in health_response["message"] assert "Health check failed: Connection failed" in health_response["message"]