From 4d8d701e767829e6e078b23ee114b4967216c8d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 18 Feb 2025 12:24:41 +0100 Subject: [PATCH] feat: Add health check endpoint to inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit introduces a HealthResponse model to standardize health status responses across the inference service. The health() method has been implemented in the Inference API, allowing the retrieval of system health information. The InferenceRouter has been updated to aggregate health statuses from various providers, ensuring a comprehensive view of system health. Additionally, the health() method has been added to multiple inference provider implementations, with a default NotImplementedError where necessary. A new /inference/health endpoint is now available, enabling monitoring and diagnostics of the inference service to improve observability and maintainability. Test: ``` $ curl http://127.0.0.1:8321/v1/inference/health {"ollama":{"health":{"status":"OK"}},"sentence-transformers":{"health":{"status":"Not Implemented"}}} ``` Signed-off-by: Sébastien Han --- docs/_static/llama-stack-spec.html | 42 +++++++++++++++++++ docs/_static/llama-stack-spec.yaml | 33 +++++++++++++++ llama_stack/apis/inference/inference.py | 32 ++++++++++++++ llama_stack/apis/inspect/inspect.py | 1 - llama_stack/distribution/routers/routers.py | 13 ++++++ .../inference/meta_reference/inference.py | 6 +++ .../sentence_transformers.py | 6 +++ .../providers/inline/inference/vllm/vllm.py | 6 +++ .../remote/inference/bedrock/bedrock.py | 6 +++ .../remote/inference/cerebras/cerebras.py | 6 +++ .../remote/inference/databricks/databricks.py | 6 +++ .../remote/inference/fireworks/fireworks.py | 6 +++ .../providers/remote/inference/groq/groq.py | 6 +++ .../remote/inference/nvidia/nvidia.py | 6 +++ .../remote/inference/ollama/ollama.py | 18 ++++++++ .../remote/inference/runpod/runpod.py | 5 +++ .../remote/inference/sambanova/sambanova.py | 5 +++ .../providers/remote/inference/tgi/tgi.py | 6 +++ .../remote/inference/together/together.py | 6 +++ .../providers/remote/inference/vllm/vllm.py | 6 +++ .../providers/tests/inference/test_health.py | 22 ++++++++++ 21 files changed, 242 insertions(+), 1 deletion(-) create mode 100644 llama_stack/providers/tests/inference/test_health.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 17cf92341..423bd27e6 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1045,6 +1045,27 @@ ] } }, + "/v1/inference/health": { + "get": { + "responses": { + "200": { + "description": "A dictionary containing the health status of the inference service.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HealthResponse" + } + } + } + } + }, + "tags": [ + "Inference" + ], + "description": "Retrieve the health status of the inference service.", + "parameters": [] + } + }, "/v1/models/{model_id}": { "get": { "responses": { @@ -5742,6 +5763,27 @@ "type" ] }, + "HealthResponse": { + "type": "object", + "properties": { + "health": { + "type": "object", + "additionalProperties": { + "type": "string", + "enum": [ + "OK", + "Error", + "Not Implemented" + ] + } + } + }, + "additionalProperties": false, + "required": [ + "health" + ], + "description": "HealthResponse is a model representing the health status response.\nparam health: A dictionary containing health information." + }, "Model": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index f63374406..aef989bdb 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -629,6 +629,21 @@ paths: required: true schema: type: string + /v1/inference/health: + get: + responses: + '200': + description: >- + A dictionary containing the health status of the inference service. + content: + application/json: + schema: + $ref: '#/components/schemas/HealthResponse' + tags: + - Inference + description: >- + Retrieve the health status of the inference service. + parameters: [] /v1/models/{model_id}: get: responses: @@ -3673,6 +3688,24 @@ components: additionalProperties: false required: - type + HealthResponse: + type: object + properties: + health: + type: object + additionalProperties: + type: string + enum: + - OK + - Error + - Not Implemented + additionalProperties: false + required: + - health + description: >- + HealthResponse is a model representing the health status response. + + param health: A dictionary containing health information. Model: type: object properties: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 433ba3274..571ef61b4 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -389,6 +389,30 @@ class EmbeddingsResponse(BaseModel): embeddings: List[List[float]] +class HealthStatus(str, Enum): + OK = "OK" + ERROR = "Error" + NOT_IMPLEMENTED = "Not Implemented" + + +@json_schema_type +class HealthResponse(BaseModel): + """ + HealthResponse is a model representing the health status response. + + param health: A dictionary containing health information. + """ + + health: Dict[str, HealthStatus] + + @field_validator("health") + @classmethod + def check_status_present(cls, v): + if "status" not in v: + raise ValueError("'status' must be present in the health dictionary.") + return v + + class ModelStore(Protocol): def get_model(self, identifier: str) -> Model: ... @@ -481,3 +505,11 @@ class Inference(Protocol): :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id} """ ... + + @webmethod(route="/inference/health", method="GET") + async def get_health(self) -> HealthResponse: + """Retrieve the health status of the inference service. + + :returns: A dictionary containing the health status of the inference service. + """ + ... diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index 4a647a2d9..3084cf03e 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -28,7 +28,6 @@ class RouteInfo(BaseModel): @json_schema_type class HealthInfo(BaseModel): status: str - # TODO: add a provider level status @json_schema_type diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index f45975189..79707284c 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -17,6 +17,8 @@ from llama_stack.apis.eval import ( ) from llama_stack.apis.inference import ( EmbeddingsResponse, + HealthResponse, + HealthStatus, Inference, LogProbConfig, Message, @@ -210,6 +212,17 @@ class InferenceRouter(Inference): contents=contents, ) + async def get_health(self) -> HealthResponse: + health_statuses = {} + for provider_id, impl in self.routing_table.impls_by_provider_id.items(): + try: + health_statuses[provider_id] = await impl.health() + except NotImplementedError: + health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.NOT_IMPLEMENTED}) + except Exception as e: + health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)}) + return health_statuses + class SafetyRouter(Safety): def __init__( diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index c79f97def..269d17d60 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -23,6 +23,7 @@ from llama_stack.apis.inference import ( CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, + HealthResponse, Inference, InterleavedContent, LogProbConfig, @@ -428,3 +429,8 @@ class MetaReferenceInferenceImpl( else: for x in impl(): yield x + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 6a83836e6..37531004f 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union from llama_stack.apis.inference import ( CompletionResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -75,3 +76,8 @@ class SentenceTransformersInferenceImpl( tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 5536ea3a5..b5d1beb69 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -23,6 +23,7 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -230,5 +231,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() + async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index e896f0597..f14eb49f3 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -17,6 +17,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseStreamChunk, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -198,3 +199,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): response_body = json.loads(response.get("body").read()) embeddings.append(response_body.get("embedding")) return EmbeddingsResponse(embeddings=embeddings) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 1ce267e8d..fceff5e77 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -16,6 +16,7 @@ from llama_stack.apis.inference import ( CompletionRequest, CompletionResponse, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -191,3 +192,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): contents: List[InterleavedContent], ) -> EmbeddingsResponse: raise NotImplementedError() + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 3d306e61f..106ca38e9 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -15,6 +15,7 @@ from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -140,3 +141,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): contents: List[InterleavedContent], ) -> EmbeddingsResponse: raise NotImplementedError() + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index acf37b248..7119a77b5 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -17,6 +17,7 @@ from llama_stack.apis.inference import ( CompletionRequest, CompletionResponse, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -297,3 +298,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 441b6af5c..6f0e93e28 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -17,6 +17,7 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + HealthResponse, Inference, InterleavedContent, LogProbConfig, @@ -154,3 +155,8 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD 'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "" }' ) return Groq(api_key=provider_data.groq_api_key) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 0c5b7c454..42ade17b7 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -17,6 +17,7 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + HealthResponse, Inference, InterleavedContent, LogProbConfig, @@ -201,3 +202,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): else: # we pass n=1 to get only one completion return convert_openai_chat_completion_choice(response.choices[0]) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index f524c0734..e99342435 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -22,6 +22,8 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + HealthResponse, + HealthStatus, Inference, LogProbConfig, Message, @@ -369,6 +371,22 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return model + async def get_health(self) -> HealthResponse: + """ + Performs a health check by initializing the service. + + This method is used by the inspect API endpoint to verify that the service is running + correctly. + + Returns: + HealthResponse: A dictionary containing the health status. + """ + try: + await self.initialize() + return HealthResponse(health={"status": HealthStatus.OK}) + except ConnectionError as e: + return HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)}) + async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]: async def _convert_content(content) -> dict: diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 1abb17336..f60ea192e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -125,3 +125,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index b906e0dcb..5da55bd2b 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -326,3 +326,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): ] return compitable_tool_calls + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 1909e01f8..481768c40 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -18,6 +18,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -308,3 +309,8 @@ class InferenceEndpointAdapter(_HfAdapter): self.client = endpoint.async_client self.model_id = endpoint.repository self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 054501da8..84ff8166a 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -16,6 +16,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -274,3 +275,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi ) embeddings = [item.embedding for item in r.data] return EmbeddingsResponse(embeddings=embeddings) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index b22284302..094294b7c 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -375,3 +376,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/tests/inference/test_health.py b/llama_stack/providers/tests/inference/test_health.py new file mode 100644 index 000000000..7c92e11cb --- /dev/null +++ b/llama_stack/providers/tests/inference/test_health.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.inference import HealthResponse + +# How to run this test: +# pytest -v -s llama_stack/providers/tests/inference/test_health.py + + +class TestHeatlh: + @pytest.mark.asyncio + async def test_health(self, inference_stack): + inference_impl, _ = inference_stack + response = await inference_impl.health() + for key in response: + assert isinstance(response[key], HealthResponse) + assert response[key].health["status"] == "OK", response