diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 665f8bd7e..c61712bfd 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -99,6 +99,17 @@ jobs: cat server.log exit 1 + - name: Verify Ollama status is OK + if: matrix.client-type == 'http' + run: | + echo "Verifying Ollama status..." + ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status) + echo "Ollama status: $ollama_status" + if [ "$ollama_status" != "OK" ]; then + echo "Ollama health check failed" + exit 1 + fi + - name: Run Integration Tests env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 542fb5be5..c85eb549f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -7889,7 +7889,13 @@ "type": "object", "properties": { "status": { - "type": "string" + "type": "string", + "enum": [ + "OK", + "Error", + "Not Implemented" + ], + "title": "HealthStatus" } }, "additionalProperties": false, @@ -8084,6 +8090,31 @@ } ] } + }, + "health": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, @@ -8091,7 +8122,8 @@ "api", "provider_id", "provider_type", - "config" + "config", + "health" ], "title": "ProviderInfo" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index fa7b130e2..6c99c9155 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -5463,6 +5463,11 @@ components: properties: status: type: string + enum: + - OK + - Error + - Not Implemented + title: HealthStatus additionalProperties: false required: - status @@ -5574,12 +5579,23 @@ components: - type: string - type: array - type: object + health: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object additionalProperties: false required: - api - provider_id - provider_type - config + - health title: ProviderInfo InvokeToolRequest: type: object diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index 3896d67a9..863f90e14 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable from pydantic import BaseModel +from llama_stack.providers.datatypes import HealthStatus from llama_stack.schema_utils import json_schema_type, webmethod @@ -20,8 +21,7 @@ class RouteInfo(BaseModel): @json_schema_type class HealthInfo(BaseModel): - status: str - # TODO: add a provider level status + status: HealthStatus @json_schema_type diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 83d03d7c1..ea5f968ec 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable from pydantic import BaseModel +from llama_stack.providers.datatypes import HealthResponse from llama_stack.schema_utils import json_schema_type, webmethod @@ -17,6 +18,7 @@ class ProviderInfo(BaseModel): provider_id: str provider_type: str config: Dict[str, Any] + health: HealthResponse class ListProvidersResponse(BaseModel): diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index ba0ce5ea2..23f644ec6 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -17,6 +17,7 @@ from llama_stack.apis.inspect import ( ) from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.server.endpoints import get_all_api_endpoints +from llama_stack.providers.datatypes import HealthStatus class DistributionInspectConfig(BaseModel): @@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect): return ListRoutesResponse(data=ret) async def health(self) -> HealthInfo: - return HealthInfo(status="OK") + return HealthInfo(status=HealthStatus.OK) async def version(self) -> VersionInfo: return VersionInfo(version=version("llama-stack")) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index c0143363d..f426bcafe 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import ( from llama_stack.distribution.stack import ( construct_stack, get_stack_run_config_from_template, - redact_sensitive_fields, replace_env_vars, ) +from llama_stack.distribution.utils.config import redact_sensitive_fields from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.distribution.utils.exec import in_notebook from llama_stack.providers.utils.telemetry.tracing import ( diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index cf9b0b975..1c00ce264 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -4,14 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any, Dict from pydantic import BaseModel from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers from llama_stack.log import get_logger +from llama_stack.providers.datatypes import HealthResponse, HealthStatus from .datatypes import StackRunConfig -from .stack import redact_sensitive_fields +from .utils.config import redact_sensitive_fields logger = get_logger(name=__name__, category="core") @@ -41,19 +44,24 @@ class ProviderImpl(Providers): async def list_providers(self) -> ListProvidersResponse: run_config = self.config.run_config safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump())) + providers_health = await self.get_providers_health() ret = [] for api, providers in safe_config.providers.items(): - ret.extend( - [ + for p in providers: + ret.append( ProviderInfo( api=api, provider_id=p.provider_id, provider_type=p.provider_type, config=p.config, + health=providers_health.get(api, {}).get( + p.provider_id, + HealthResponse( + status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check" + ), + ), ) - for p in providers - ] - ) + ) return ListProvidersResponse(data=ret) @@ -64,3 +72,57 @@ class ProviderImpl(Providers): return p raise ValueError(f"Provider {provider_id} not found") + + async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]: + """Get health status for all providers. + + Returns: + Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses. + Each API maps to a dictionary of provider IDs to their health responses. + """ + providers_health: Dict[str, Dict[str, HealthResponse]] = {} + timeout = 1.0 + + async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None: + # Skip special implementations (inspect/providers) that don't have provider specs + if not hasattr(impl, "__provider_spec__"): + return None + api_name = impl.__provider_spec__.api.name + if not hasattr(impl, "health"): + return ( + api_name, + HealthResponse( + status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check" + ), + ) + + try: + health = await asyncio.wait_for(impl.health(), timeout=timeout) + return api_name, health + except asyncio.TimeoutError: + return ( + api_name, + HealthResponse( + status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds" + ), + ) + except Exception as e: + return ( + api_name, + HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"), + ) + + # Create tasks for all providers + tasks = [check_provider_health(impl) for impl in self.deps.values()] + + # Wait for all health checks to complete + results = await asyncio.gather(*tasks) + + # Organize results by API and provider ID + for result in results: + if result is None: # Skip special implementations + continue + api_name, health_response = result + providers_health[api_name] = health_response + + return providers_health diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 0de1e0a02..e9a594eba 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import ( Api, BenchmarksProtocolPrivate, DatasetsProtocolPrivate, - InlineProviderSpec, ModelsProtocolPrivate, ProviderSpec, RemoteProviderConfig, @@ -230,46 +229,6 @@ def sort_providers_by_deps( {k: list(v.values()) for k, v in providers_with_specs.items()} ) - # Append built-in "inspect" provider - apis = [x[1].spec.api for x in sorted_providers] - sorted_providers.append( - ( - "inspect", - ProviderWithSpec( - provider_id="__builtin__", - provider_type="__builtin__", - config={"run_config": run_config.model_dump()}, - spec=InlineProviderSpec( - api=Api.inspect, - provider_type="__builtin__", - config_class="llama_stack.distribution.inspect.DistributionInspectConfig", - module="llama_stack.distribution.inspect", - api_dependencies=apis, - deps__=[x.value for x in apis], - ), - ), - ) - ) - - sorted_providers.append( - ( - "providers", - ProviderWithSpec( - provider_id="__builtin__", - provider_type="__builtin__", - config={"run_config": run_config.model_dump()}, - spec=InlineProviderSpec( - api=Api.providers, - provider_type="__builtin__", - config_class="llama_stack.distribution.providers.ProviderImplConfig", - module="llama_stack.distribution.providers", - api_dependencies=apis, - deps__=[x.value for x in apis], - ), - ), - ) - ) - logger.debug(f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: logger.debug(f" {api_str} => {provider.provider_id}") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index b9623ef3c..cdf91e052 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union @@ -60,7 +61,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.providers.datatypes import RoutingTable +from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") @@ -580,6 +581,29 @@ class InferenceRouter(Inference): provider = self.routing_table.get_provider_impl(model_obj.identifier) return await provider.openai_chat_completion(**params) + async def health(self) -> Dict[str, HealthResponse]: + health_statuses = {} + timeout = 0.5 + for provider_id, impl in self.routing_table.impls_by_provider_id.items(): + try: + # check if the provider has a health method + if not hasattr(impl, "health"): + continue + health = await asyncio.wait_for(impl.health(), timeout=timeout) + health_statuses[provider_id] = health + except asyncio.TimeoutError: + health_statuses[provider_id] = HealthResponse( + status=HealthStatus.ERROR, + message=f"Health check timed out after {timeout} seconds", + ) + except NotImplementedError: + health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) + except Exception as e: + health_statuses[provider_id] = HealthResponse( + status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" + ) + return health_statuses + class SafetyRouter(Safety): def __init__( diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 7d4ec2a2f..d7ef37c26 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import ( ) from llama_stack.distribution.stack import ( construct_stack, - redact_sensitive_fields, replace_env_vars, validate_env_pair, ) +from llama_stack.distribution.utils.config import redact_sensitive_fields from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 08ff5e7cd..a6dc3d2a0 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO from llama_stack.distribution.datatypes import Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl +from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -119,26 +121,6 @@ class EnvVarError(Exception): super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}") -def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: - """Redact sensitive information from config before printing.""" - sensitive_patterns = ["api_key", "api_token", "password", "secret"] - - def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]: - result = {} - for k, v in d.items(): - if isinstance(v, dict): - result[k] = _redact_dict(v) - elif isinstance(v, list): - result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v] - elif any(pattern in k.lower() for pattern in sensitive_patterns): - result[k] = "********" - else: - result[k] = v - return result - - return _redact_dict(data) - - def replace_env_vars(config: Any, path: str = "") -> Any: if isinstance(config, dict): result = {} @@ -215,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]: ) from e +def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None: + """Add internal implementations (inspect and providers) to the implementations dictionary. + + Args: + impls: Dictionary of API implementations + run_config: Stack run configuration + """ + inspect_impl = DistributionInspectImpl( + DistributionInspectConfig(run_config=run_config), + deps=impls, + ) + impls[Api.inspect] = inspect_impl + + providers_impl = ProviderImpl( + ProviderImplConfig(run_config=run_config), + deps=impls, + ) + impls[Api.providers] = providers_impl + + # Produces a stack of providers for the given run config. Not all APIs may be # asked for in the run config. async def construct_stack( @@ -222,6 +224,10 @@ async def construct_stack( ) -> Dict[Api, Any]: dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) + + # Add internal implementations after all other providers are resolved + add_internal_implementations(impls, run_config) + await register_resources(run_config, impls) return impls diff --git a/llama_stack/distribution/utils/config.py b/llama_stack/distribution/utils/config.py new file mode 100644 index 000000000..5e78289b7 --- /dev/null +++ b/llama_stack/distribution/utils/config.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + + +def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: + """Redact sensitive information from config before printing.""" + sensitive_patterns = ["api_key", "api_token", "password", "secret"] + + def _redact_value(v: Any) -> Any: + if isinstance(v, dict): + return _redact_dict(v) + elif isinstance(v, list): + return [_redact_value(i) for i in v] + return v + + def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]: + result = {} + for k, v in d.items(): + if any(pattern in k.lower() for pattern in sensitive_patterns): + result[k] = "********" + else: + result[k] = _redact_value(v) + return result + + return _redact_dict(data) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 32dfba30c..c3141f807 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Any, List, Optional, Protocol from urllib.parse import urlparse @@ -201,3 +202,12 @@ def remote_provider_spec( adapter=adapter, api_dependencies=api_dependencies or [], ) + + +class HealthStatus(str, Enum): + OK = "OK" + ERROR = "Error" + NOT_IMPLEMENTED = "Not Implemented" + + +HealthResponse = dict[str, Any] diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 33b48af46..f84863385 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -42,7 +42,11 @@ from llama_stack.apis.inference import ( from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger -from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.datatypes import ( + HealthResponse, + HealthStatus, + ModelsProtocolPrivate, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -87,8 +91,19 @@ class OllamaInferenceAdapter( async def initialize(self) -> None: logger.info(f"checking connectivity to Ollama at `{self.url}`...") + await self.health() + + async def health(self) -> HealthResponse: + """ + Performs a health check by verifying connectivity to the Ollama server. + This method is used by initialize() and the Provider API to verify that the service is running + correctly. + Returns: + HealthResponse: A dictionary containing the health status. + """ try: await self.client.ps() + return HealthResponse(status=HealthStatus.OK) except httpx.ConnectError as e: raise RuntimeError( "Ollama Server is not running, start it using `ollama serve` in a separate terminal"