diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 17cf92341..423bd27e6 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -1045,6 +1045,27 @@
                 ]
             }
         },
+        "/v1/inference/health": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "A dictionary containing the health status of the inference service.",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HealthResponse"
+                                }
+                            }
+                        }
+                    }
+                },
+                "tags": [
+                    "Inference"
+                ],
+                "description": "Retrieve the health status of the inference service.",
+                "parameters": []
+            }
+        },
         "/v1/models/{model_id}": {
             "get": {
                 "responses": {
@@ -5742,6 +5763,27 @@
                     "type"
                 ]
            },
+            "HealthResponse": {
+                "type": "object",
+                "properties": {
+                    "health": {
+                        "type": "object",
+                        "additionalProperties": {
+                            "type": "string",
+                            "enum": [
+                                "OK",
+                                "Error",
+                                "Not Implemented"
+                            ]
+                        }
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "health"
+                ],
+                "description": "HealthResponse is a model representing the health status response.\nparam health: A dictionary containing health information."
+            },
             "Model": {
                 "type": "object",
                 "properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index f63374406..aef989bdb 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -629,6 +629,21 @@ paths:
         required: true
         schema:
           type: string
+  /v1/inference/health:
+    get:
+      responses:
+        '200':
+          description: >-
+            A dictionary containing the health status of the inference service.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HealthResponse'
+      tags:
+        - Inference
+      description: >-
+        Retrieve the health status of the inference service.
+      parameters: []
   /v1/models/{model_id}:
     get:
       responses:
@@ -3673,6 +3688,24 @@ components:
       additionalProperties: false
      required:
        - type
+    HealthResponse:
+      type: object
+      properties:
+        health:
+          type: object
+          additionalProperties:
+            type: string
+            enum:
+              - OK
+              - Error
+              - Not Implemented
+      additionalProperties: false
+      required:
+        - health
+      description: >-
+        HealthResponse is a model representing the health status response.
+
+        param health: A dictionary containing health information.
    Model:
      type: object
      properties:
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 433ba3274..571ef61b4 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -389,6 +389,30 @@ class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]
 
 
+class HealthStatus(str, Enum):
+    OK = "OK"
+    ERROR = "Error"
+    NOT_IMPLEMENTED = "Not Implemented"
+
+
+@json_schema_type
+class HealthResponse(BaseModel):
+    """
+    HealthResponse is a model representing the health status response.
+
+    param health: A dictionary containing health information.
+    """
+
+    health: Dict[str, HealthStatus]
+
+    @field_validator("health")
+    @classmethod
+    def check_status_present(cls, v):
+        if "status" not in v:
+            raise ValueError("'status' must be present in the health dictionary.")
+        return v
+
+
 class ModelStore(Protocol):
     def get_model(self, identifier: str) -> Model: ...
 
@@ -481,3 +505,11 @@
         :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
         """
         ...
+
+    @webmethod(route="/inference/health", method="GET")
+    async def get_health(self) -> HealthResponse:
+        """Retrieve the health status of the inference service.
+
+        :returns: A dictionary containing the health status of the inference service.
+        """
+        ...
diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py
index 4a647a2d9..3084cf03e 100644
--- a/llama_stack/apis/inspect/inspect.py
+++ b/llama_stack/apis/inspect/inspect.py
@@ -28,7 +28,6 @@ class RouteInfo(BaseModel):
 @json_schema_type
 class HealthInfo(BaseModel):
     status: str
-    # TODO: add a provider level status
 
 
 @json_schema_type
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index f45975189..79707284c 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -17,6 +17,8 @@ from llama_stack.apis.eval import (
 )
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
+    HealthResponse,
+    HealthStatus,
     Inference,
     LogProbConfig,
     Message,
@@ -210,6 +212,17 @@
             contents=contents,
         )
 
+    async def get_health(self) -> HealthResponse:
+        health_statuses = {}
+        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
+            try:
+                health_statuses[provider_id] = await impl.get_health()
+            except NotImplementedError:
+                health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.NOT_IMPLEMENTED})
+            except Exception as e:
+                health_statuses[provider_id] = HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)})
+        return health_statuses
+
 
 class SafetyRouter(Safety):
     def __init__(
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index c79f97def..269d17d60 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    HealthResponse,
     Inference,
     InterleavedContent,
     LogProbConfig,
@@ -428,3 +429,8 @@
         else:
             for x in impl():
                 yield x
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 6a83836e6..37531004f 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from llama_stack.apis.inference import (
     CompletionResponse,
+    HealthResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -75,3 +76,8 @@
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 5536ea3a5..b5d1beb69 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    HealthResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -230,5 +231,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
+
     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index e896f0597..f14eb49f3 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
+    HealthResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -198,3 +199,8 @@
             response_body = json.loads(response.get("body").read())
             embeddings.append(response_body.get("embedding"))
         return EmbeddingsResponse(embeddings=embeddings)
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 1ce267e8d..fceff5e77 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -16,6 +16,7 @@ from llama_stack.apis.inference import (
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    HealthResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -191,3 +192,8 @@
         contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 3d306e61f..106ca38e9 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -15,6 +15,7 @@ from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     EmbeddingsResponse,
+    HealthResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -140,3 +141,8 @@
         contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index acf37b248..7119a77b5 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    HealthResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -297,3 +298,8 @@
 
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 441b6af5c..6f0e93e28 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    HealthResponse,
     Inference,
     InterleavedContent,
     LogProbConfig,
@@ -154,3 +155,8 @@
                 'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "" }'
             )
         return Groq(api_key=provider_data.groq_api_key)
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 0c5b7c454..42ade17b7 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    HealthResponse,
     Inference,
     InterleavedContent,
     LogProbConfig,
@@ -201,3 +202,8 @@
         else:
             # we pass n=1 to get only one completion
             return convert_openai_chat_completion_choice(response.choices[0])
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index f524c0734..e99342435 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -22,6 +22,8 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    HealthResponse,
+    HealthStatus,
     Inference,
     LogProbConfig,
     Message,
@@ -369,6 +371,22 @@
 
         return model
 
+    async def get_health(self) -> HealthResponse:
+        """
+        Performs a health check by initializing the service.
+
+        This method is used by the inspect API endpoint to verify that the service is running
+        correctly.
+
+        Returns:
+            HealthResponse: A dictionary containing the health status.
+ """ + try: + await self.initialize() + return HealthResponse(health={"status": HealthStatus.OK}) + except ConnectionError as e: + return HealthResponse(health={"status": HealthStatus.ERROR, "message": str(e)}) + async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]: async def _convert_content(content) -> dict: diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 1abb17336..f60ea192e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -125,3 +125,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index b906e0dcb..5da55bd2b 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -326,3 +326,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): ] return compitable_tool_calls + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 1909e01f8..481768c40 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -18,6 +18,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -308,3 +309,8 @@ class InferenceEndpointAdapter(_HfAdapter): self.client = endpoint.async_client self.model_id = endpoint.repository self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 054501da8..84ff8166a 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -16,6 +16,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -274,3 +275,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi ) embeddings = [item.embedding for item in r.data] return EmbeddingsResponse(embeddings=embeddings) + + async def get_health( + self, + ) -> HealthResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index b22284302..094294b7c 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + HealthResponse, Inference, LogProbConfig, Message, @@ -375,3 +376,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) 
+
+    async def get_health(
+        self,
+    ) -> HealthResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/tests/inference/test_health.py b/llama_stack/providers/tests/inference/test_health.py
new file mode 100644
index 000000000..7c92e11cb
--- /dev/null
+++ b/llama_stack/providers/tests/inference/test_health.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.inference import HealthResponse
+
+# How to run this test:
+# pytest -v -s llama_stack/providers/tests/inference/test_health.py
+
+
+class TestHealth:
+    @pytest.mark.asyncio
+    async def test_health(self, inference_stack):
+        inference_impl, _ = inference_stack
+        response = await inference_impl.get_health()
+        for key in response:
+            assert isinstance(response[key], HealthResponse)
+            assert response[key].health["status"] == "OK", response
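
For reference, a minimal usage sketch of the HealthResponse model introduced above. It is not part of the patch; it assumes the change is installed and only illustrates the str-backed HealthStatus comparison used in the test and the field_validator that requires a "status" key.

from pydantic import ValidationError

from llama_stack.apis.inference import HealthResponse, HealthStatus

# A well-formed health payload: HealthStatus mixes in str, so the stored value
# compares equal to the plain string "OK" (as asserted in test_health.py).
ok = HealthResponse(health={"status": HealthStatus.OK})
assert ok.health["status"] == "OK"

# The check_status_present validator rejects payloads without a "status" key;
# pydantic surfaces the ValueError it raises as a ValidationError.
try:
    HealthResponse(health={"latency": HealthStatus.OK})
except ValidationError:
    print("rejected: 'status' must be present in the health dictionary.")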