mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
feat: Add health check endpoint to inference
This commit introduces a HealthResponse model to standardize health status responses across the inference service. The health() method has been implemented in the Inference API, allowing the retrieval of system health information. The InferenceRouter has been updated to aggregate health statuses from various providers, ensuring a comprehensive view of system health. Additionally, the health() method has been added to multiple inference provider implementations, with a default NotImplementedError where necessary. A new /inference/health endpoint is now available, enabling monitoring and diagnostics of the inference service to improve observability and maintainability. Test: ``` $ curl http://127.0.0.1:8321/v1/inference/health {"ollama":{"health":{"status":"OK"}},"sentence-transformers":{"health":{"status":"Not Implemented"}}} ``` Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
6b1773d530
commit
4d8d701e76
21 changed files with 242 additions and 1 deletions
42
docs/_static/llama-stack-spec.html
vendored
42
docs/_static/llama-stack-spec.html
vendored
|
@ -1045,6 +1045,27 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/v1/inference/health": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A dictionary containing the health status of the inference service.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HealthResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Inference"
|
||||
],
|
||||
"description": "Retrieve the health status of the inference service.",
|
||||
"parameters": []
|
||||
}
|
||||
},
|
||||
"/v1/models/{model_id}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -5742,6 +5763,27 @@
|
|||
"type"
|
||||
]
|
||||
},
|
||||
"HealthResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"health": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"OK",
|
||||
"Error",
|
||||
"Not Implemented"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"health"
|
||||
],
|
||||
"description": "HealthResponse is a model representing the health status response.\nparam health: A dictionary containing health information."
|
||||
},
|
||||
"Model": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
33
docs/_static/llama-stack-spec.yaml
vendored
33
docs/_static/llama-stack-spec.yaml
vendored
|
@ -629,6 +629,21 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/inference/health:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A dictionary containing the health status of the inference service.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HealthResponse'
|
||||
tags:
|
||||
- Inference
|
||||
description: >-
|
||||
Retrieve the health status of the inference service.
|
||||
parameters: []
|
||||
/v1/models/{model_id}:
|
||||
get:
|
||||
responses:
|
||||
|
@ -3673,6 +3688,24 @@ components:
|
|||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
HealthResponse:
|
||||
type: object
|
||||
properties:
|
||||
health:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
enum:
|
||||
- OK
|
||||
- Error
|
||||
- Not Implemented
|
||||
additionalProperties: false
|
||||
required:
|
||||
- health
|
||||
description: >-
|
||||
HealthResponse is a model representing the health status response.
|
||||
|
||||
param health: A dictionary containing health information.
|
||||
Model:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -389,6 +389,30 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: List[List[float]]
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
OK = "OK"
|
||||
ERROR = "Error"
|
||||
NOT_IMPLEMENTED = "Not Implemented"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class HealthResponse(BaseModel):
|
||||
"""
|
||||
HealthResponse is a model representing the health status response.
|
||||
|
||||
param health: A dictionary containing health information.
|
||||
"""
|
||||
|
||||
health: Dict[str, HealthStatus]
|
||||
|
||||
@field_validator("health")
|
||||
@classmethod
|
||||
def check_status_present(cls, v):
|
||||
if "status" not in v:
|
||||
raise ValueError("'status' must be present in the health dictionary.")
|
||||
return v
|
||||
|
||||
|
||||
class ModelStore(Protocol):
|
||||
def get_model(self, identifier: str) -> Model: ...
|
||||
|
||||
|
@ -481,3 +505,11 @@ class Inference(Protocol):
|
|||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/health", method="GET")
|
||||
async def get_health(self) -> HealthResponse:
|
||||
"""Retrieve the health status of the inference service.
|
||||
|
||||
:returns: A dictionary containing the health status of the inference service.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -28,7 +28,6 @@ class RouteInfo(BaseModel):
|
|||
@json_schema_type
|
||||
class HealthInfo(BaseModel):
|
||||
status: str
|
||||
# TODO: add a provider level status
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -17,6 +17,8 @@ from llama_stack.apis.eval import (
|
|||
)
|
||||
from llama_stack.apis.inference import (
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -210,6 +212,17 @@ class InferenceRouter(Inference):
|
|||
contents=contents,
|
||||
)
|
||||
|
||||
async def get_health(self) -> HealthResponse:
|
||||
health_statuses = {}
|
||||
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||
try:
|
||||
health_statuses[provider_id] = await impl.health()
|
||||
except NotImplementedError:
|
||||
health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.NOT_IMPLEMENTED})
|
||||
except Exception as e:
|
||||
health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)})
|
||||
return health_statuses
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
|
@ -428,3 +429,8 @@ class MetaReferenceInferenceImpl(
|
|||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -75,3 +76,8 @@ class SentenceTransformersInferenceImpl(
|
|||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -230,5 +231,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
|
||||
yield chunk
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -198,3 +199,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -191,3 +192,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -140,3 +141,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -297,3 +298,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
|
@ -154,3 +155,8 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
|
|||
'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
|
||||
)
|
||||
return Groq(api_key=provider_data.groq_api_key)
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
|
@ -201,3 +202,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -22,6 +22,8 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -369,6 +371,22 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
return model
|
||||
|
||||
async def get_health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by initializing the service.
|
||||
|
||||
This method is used by the inspect API endpoint to verify that the service is running
|
||||
correctly.
|
||||
|
||||
Returns:
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.initialize()
|
||||
return HealthResponse(health={"status": HealthStatus.OK})
|
||||
except ConnectionError as e:
|
||||
return HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)})
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
|
|
|
@ -125,3 +125,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -326,3 +326,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
]
|
||||
|
||||
return compitable_tool_calls
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -18,6 +18,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -308,3 +309,8 @@ class InferenceEndpointAdapter(_HfAdapter):
|
|||
self.client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -274,3 +275,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
)
|
||||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
HealthResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -375,3 +376,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def get_health(
|
||||
self,
|
||||
) -> HealthResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
22
llama_stack/providers/tests/inference/test_health.py
Normal file
22
llama_stack/providers/tests/inference/test_health.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import HealthResponse
|
||||
|
||||
# How to run this test:
|
||||
# pytest -v -s llama_stack/providers/tests/inference/test_health.py
|
||||
|
||||
|
||||
class TestHeatlh:
|
||||
@pytest.mark.asyncio
|
||||
async def test_health(self, inference_stack):
|
||||
inference_impl, _ = inference_stack
|
||||
response = await inference_impl.health()
|
||||
for key in response:
|
||||
assert isinstance(response[key], HealthResponse)
|
||||
assert response[key].health["status"] == "OK", response
|
Loading…
Add table
Add a link
Reference in a new issue