watsonx health check implementation

This commit is contained in:
Sumit Jaiswal 2025-05-29 01:41:22 +05:30
parent a654467552
commit 2e81a8f020
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
3 changed files with 85 additions and 0 deletions

View file

@ -32,6 +32,10 @@ class WatsonXConfig(BaseModel):
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
)
model_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_MODEL_ID", "ibm/granite-3-8b-instruct"),
description="The Model ID key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -39,6 +40,10 @@ from llama_stack.apis.inference.inference import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -75,6 +80,26 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
async def shutdown(self) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Watsonx server.
This method is used by initialize() and the Provider API to verify
that the service is running correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
model = self._get_client(self._config.model_id)
model.generate("test")
return HealthResponse(
status=HealthStatus.OK
)
except Exception as ex:
return HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check failed: {str(ex)}"
)
async def completion(
self,
model_id: str,