feat: Add health check endpoint to inference

This commit introduces a HealthResponse model to standardize health
status responses across the inference service. The health() method has
been implemented in the Inference API, allowing the retrieval of system
health information.
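
For orientation, here is a minimal, self-contained sketch of the response model's shape and how it serializes. The field and status names are taken from the diff below; this standalone pydantic-v2 definition is illustrative only, not the committed code:

```
# Illustrative sketch only -- mirrors the HealthStatus/HealthResponse added below.
from enum import Enum
from typing import Dict

from pydantic import BaseModel


class HealthStatus(str, Enum):
    OK = "OK"
    ERROR = "Error"
    NOT_IMPLEMENTED = "Not Implemented"


class HealthResponse(BaseModel):
    # Maps a key such as "status" to one of the allowed status values.
    health: Dict[str, HealthStatus]


print(HealthResponse(health={"status": HealthStatus.OK}).model_dump_json())
# -> {"health":{"status":"OK"}}
```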

The InferenceRouter has been updated to aggregate health statuses from
all registered providers, giving a comprehensive view of system health.
The health() method has also been added to the individual inference
provider implementations; providers that do not yet support a health
probe fall back to a default NotImplementedError, which the router
reports as "Not Implemented".
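
The fallback semantics are easy to see in isolation. The sketch below uses hypothetical stand-in provider classes and plain dicts in place of HealthResponse, so it is illustrative only:

```
# Sketch of the router's per-provider fallback behaviour (stand-in classes, not the real providers).
import asyncio


class _HealthyProvider:
    async def health(self):
        return {"health": {"status": "OK"}}


class _LegacyProvider:
    async def health(self):
        # Providers without a real probe keep the default NotImplementedError.
        raise NotImplementedError()


async def aggregate(providers: dict) -> dict:
    statuses = {}
    for provider_id, impl in providers.items():
        try:
            statuses[provider_id] = await impl.health()
        except NotImplementedError:
            statuses[provider_id] = {"health": {"status": "Not Implemented"}}
        except Exception as e:  # any other failure is reported rather than raised
            statuses[provider_id] = {"health": {"status": "Error", "message": str(e)}}
    return statuses


result = asyncio.run(aggregate({"ollama": _HealthyProvider(), "sentence-transformers": _LegacyProvider()}))
print(result)
# -> {'ollama': {'health': {'status': 'OK'}}, 'sentence-transformers': {'health': {'status': 'Not Implemented'}}}
```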

A new /inference/health endpoint (served under the /v1 prefix) is now
available, enabling monitoring and diagnostics of the inference service
and improving its observability.

Test:

```
$ curl http://127.0.0.1:8321/v1/inference/health
{"ollama":{"health":{"status":"OK"}},"sentence-transformers":{"health":{"status":"Not Implemented"}}}
```
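
The same check can be scripted for monitoring. A minimal sketch using the `requests` package follows; the host/port and the policy of failing only on an "Error" status are assumptions for illustration, not part of this commit:

```
# Illustrative readiness probe against the new endpoint (assumes a locally running
# server on 127.0.0.1:8321 and the `requests` package; neither is provided by this commit).
import requests

resp = requests.get("http://127.0.0.1:8321/v1/inference/health", timeout=5)
resp.raise_for_status()
statuses = resp.json()  # e.g. {"ollama": {"health": {"status": "OK"}}, ...}

errored = {p: s for p, s in statuses.items() if s["health"]["status"] == "Error"}
if errored:
    raise SystemExit(f"unhealthy inference providers: {errored}")
print("providers reporting:", ", ".join(statuses))
```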

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-02-18 12:24:41 +01:00
parent 6b1773d530
commit 4d8d701e76
21 changed files with 242 additions and 1 deletion

@@ -1045,6 +1045,27 @@
]
}
},
"/v1/inference/health": {
"get": {
"responses": {
"200": {
"description": "A dictionary containing the health status of the inference service.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HealthResponse"
}
}
}
}
},
"tags": [
"Inference"
],
"description": "Retrieve the health status of the inference service.",
"parameters": []
}
},
"/v1/models/{model_id}": {
"get": {
"responses": {
@@ -5742,6 +5763,27 @@
"type"
]
},
"HealthResponse": {
"type": "object",
"properties": {
"health": {
"type": "object",
"additionalProperties": {
"type": "string",
"enum": [
"OK",
"Error",
"Not Implemented"
]
}
}
},
"additionalProperties": false,
"required": [
"health"
],
"description": "HealthResponse is a model representing the health status response.\nparam health: A dictionary containing health information."
},
"Model": {
"type": "object",
"properties": {

@@ -629,6 +629,21 @@ paths:
required: true
schema:
type: string
/v1/inference/health:
get:
responses:
'200':
description: >-
A dictionary containing the health status of the inference service.
content:
application/json:
schema:
$ref: '#/components/schemas/HealthResponse'
tags:
- Inference
description: >-
Retrieve the health status of the inference service.
parameters: []
/v1/models/{model_id}:
get:
responses:
@@ -3673,6 +3688,24 @@ components:
additionalProperties: false
required:
- type
HealthResponse:
type: object
properties:
health:
type: object
additionalProperties:
type: string
enum:
- OK
- Error
- Not Implemented
additionalProperties: false
required:
- health
description: >-
HealthResponse is a model representing the health status response.
param health: A dictionary containing health information.
Model:
type: object
properties:

@@ -389,6 +389,30 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class HealthStatus(str, Enum):
OK = "OK"
ERROR = "Error"
NOT_IMPLEMENTED = "Not Implemented"
@json_schema_type
class HealthResponse(BaseModel):
"""
HealthResponse is a model representing the health status response.
param health: A dictionary containing health information.
"""
health: Dict[str, HealthStatus]
@field_validator("health")
@classmethod
def check_status_present(cls, v):
if "status" not in v:
raise ValueError("'status' must be present in the health dictionary.")
return v
class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
@@ -481,3 +505,11 @@ class Inference(Protocol):
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
@webmethod(route="/inference/health", method="GET")
async def health(self) -> HealthResponse:
"""Retrieve the health status of the inference service.
:returns: A dictionary containing the health status of the inference service.
"""
...

@@ -28,7 +28,6 @@ class RouteInfo(BaseModel):
@json_schema_type
class HealthInfo(BaseModel):
status: str
# TODO: add a provider level status
@json_schema_type

@@ -17,6 +17,8 @@ from llama_stack.apis.eval import (
)
from llama_stack.apis.inference import (
EmbeddingsResponse,
HealthResponse,
HealthStatus,
Inference,
LogProbConfig,
Message,
@@ -210,6 +212,17 @@ class InferenceRouter(Inference):
contents=contents,
)
async def health(self) -> HealthResponse:
health_statuses = {}
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
health_statuses[provider_id] = await impl.health()
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.NOT_IMPLEMENTED})
except Exception as e:
health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)})
return health_statuses
class SafetyRouter(Safety):
def __init__(

@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
HealthResponse,
Inference,
InterleavedContent,
LogProbConfig,
@@ -428,3 +429,8 @@ class MetaReferenceInferenceImpl(
else:
for x in impl():
yield x
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
CompletionResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -75,3 +76,8 @@ class SentenceTransformersInferenceImpl(
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -230,5 +231,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
yield chunk
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
raise NotImplementedError()

@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -198,3 +199,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
return EmbeddingsResponse(embeddings=embeddings)
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -16,6 +16,7 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -191,3 +192,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -15,6 +15,7 @@ from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -140,3 +141,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -297,3 +298,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
HealthResponse,
Inference,
InterleavedContent,
LogProbConfig,
@@ -154,3 +155,8 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
)
return Groq(api_key=provider_data.groq_api_key)
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
HealthResponse,
Inference,
InterleavedContent,
LogProbConfig,
@@ -201,3 +202,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -22,6 +22,8 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
HealthResponse,
HealthStatus,
Inference,
LogProbConfig,
Message,
@@ -369,6 +371,22 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return model
async def health(self) -> HealthResponse:
"""
Performs a health check by initializing the service.
This method is used by the inspect API endpoint to verify that the service is running
correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
await self.initialize()
return HealthResponse(health={"status": HealthStatus.OK})
except ConnectionError as e:
return HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)})
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:

@@ -125,3 +125,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -326,3 +326,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
]
return compitable_tool_calls
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -18,6 +18,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -308,3 +309,8 @@ class InferenceEndpointAdapter(_HfAdapter):
self.client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -16,6 +16,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -274,3 +275,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
)
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
HealthResponse,
Inference,
LogProbConfig,
Message,
@@ -375,3 +376,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def health(
self,
) -> HealthResponse:
raise NotImplementedError()

@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import HealthResponse
# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_health.py
class TestHealth:
@pytest.mark.asyncio
async def test_health(self, inference_stack):
inference_impl, _ = inference_stack
response = await inference_impl.health()
for key in response:
assert isinstance(response[key], HealthResponse)
assert response[key].health["status"] == "OK", response