PR to implement watsonx health check

This commit is contained in:
Sumit Jaiswal 2025-06-01 18:10:32 +05:30
parent e2e15ebb6c
commit 8de3feb53b
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
3 changed files with 89 additions and 0 deletions

View file

@ -32,6 +32,10 @@ class WatsonXConfig(BaseModel):
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
)
model_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_MODEL_ID", "ibm/granite-3-8b-instruct"),
description="The Model ID key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",

View file

@ -40,6 +40,10 @@ from llama_stack.apis.inference.inference import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -76,6 +80,21 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
async def shutdown(self) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Watsonx server.
This method is used by the Provider API to verify
that the service is running correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
model = self._get_client(self._config.model_id)
model.generate("test")
return HealthResponse(status=HealthStatus.OK)
except Exception as ex:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}")
async def completion(
self,
model_id: str,

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import pytest
import pytest_asyncio
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
@pytest.fixture
def watsonx_config():
"""Create a WatsonXConfig fixture for testing."""
return WatsonXConfig(
url="https://test-watsonx-url.ibm.com",
api_key="test-api-key",
project_id="test-project-id",
model_id="test-model-id",
)
@pytest_asyncio.fixture
async def watsonx_inference_adapter(watsonx_config):
"""Create a WatsonX InferenceAdapter fixture for testing."""
adapter = WatsonXInferenceAdapter(watsonx_config)
await adapter.initialize()
return adapter
@pytest.mark.asyncio
async def test_health_success(watsonx_inference_adapter):
"""
Test the health status of WatsonX InferenceAdapter when the connection is successful.
This test verifies that the health method returns a HealthResponse with status OK, only
when the connection to the WatsonX server is successful.
"""
# Mock the _get_client method to return a mock model
mock_model = MagicMock()
mock_model.generate.return_value = "test response"
with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model):
health_response = await watsonx_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
mock_model.generate.assert_called_once_with("test")
@pytest.mark.asyncio
async def test_health_failure(watsonx_inference_adapter):
"""
Test the health method of WatsonX InferenceAdapter when the connection fails.
This test verifies that the health method returns a HealthResponse with status ERROR,
with the exception error message.
"""
mock_model = MagicMock()
mock_model.generate.side_effect = Exception("Connection failed")
with patch.object(watsonx_inference_adapter, "_get_client", return_value=mock_model):
health_response = await watsonx_inference_adapter.health()
assert health_response["status"] == HealthStatus.ERROR
assert "Health check failed: Connection failed" in health_response["message"]