PR to implement watsonx health check

This commit is contained in:
Sumit Jaiswal 2025-06-01 18:10:32 +05:30
parent e2e15ebb6c
commit 8de3feb53b
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
3 changed files with 89 additions and 0 deletions

View file

@ -32,6 +32,10 @@ class WatsonXConfig(BaseModel):
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
)
model_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_MODEL_ID", "ibm/granite-3-8b-instruct"),
description="The Model ID key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",

View file

@ -40,6 +40,10 @@ from llama_stack.apis.inference.inference import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -76,6 +80,21 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
async def shutdown(self) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Watsonx server.
This method is used by the Provider API to verify
that the service is running correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
model = self._get_client(self._config.model_id)
model.generate("test")
return HealthResponse(status=HealthStatus.OK)
except Exception as ex:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}")
async def completion(
self,
model_id: str,