feat: add health to all providers through providers endpoint (#1418)

The `/v1/providers` endpoint now reports the health status of each
provider, when the provider implements a health check. Providers that do
not implement one report a status of `Not Implemented`.

```
curl -sL http://127.0.0.1:8321/v1/providers | jq
{
  "data": [
    {
      "api": "inference",
      "provider_id": "ollama",
      "provider_type": "remote::ollama",
      "config": {
        "url": "http://localhost:11434"
      },
      "health": {
        "status": "OK"
      }
    },
    {
      "api": "vector_io",
      "provider_id": "faiss",
      "provider_type": "inline::faiss",
      "config": {
        "kvstore": {
          "type": "sqlite",
          "namespace": null,
          "db_path": "/Users/leseb/.llama/distributions/ollama/faiss_store.db"
        }
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "safety",
      "provider_id": "llama-guard",
      "provider_type": "inline::llama-guard",
      "config": {
        "excluded_categories": []
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "agents",
      "provider_id": "meta-reference",
      "provider_type": "inline::meta-reference",
      "config": {
        "persistence_store": {
          "type": "sqlite",
          "namespace": null,
          "db_path": "/Users/leseb/.llama/distributions/ollama/agents_store.db"
        }
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "telemetry",
      "provider_id": "meta-reference",
      "provider_type": "inline::meta-reference",
      "config": {
        "service_name": "llama-stack",
        "sinks": "console,sqlite",
        "sqlite_db_path": "/Users/leseb/.llama/distributions/ollama/trace_store.db"
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "eval",
      "provider_id": "meta-reference",
      "provider_type": "inline::meta-reference",
      "config": {
        "kvstore": {
          "type": "sqlite",
          "namespace": null,
          "db_path": "/Users/leseb/.llama/distributions/ollama/meta_reference_eval.db"
        }
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "datasetio",
      "provider_id": "huggingface",
      "provider_type": "remote::huggingface",
      "config": {
        "kvstore": {
          "type": "sqlite",
          "namespace": null,
          "db_path": "/Users/leseb/.llama/distributions/ollama/huggingface_datasetio.db"
        }
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "datasetio",
      "provider_id": "localfs",
      "provider_type": "inline::localfs",
      "config": {
        "kvstore": {
          "type": "sqlite",
          "namespace": null,
          "db_path": "/Users/leseb/.llama/distributions/ollama/localfs_datasetio.db"
        }
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "scoring",
      "provider_id": "basic",
      "provider_type": "inline::basic",
      "config": {},
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "scoring",
      "provider_id": "llm-as-judge",
      "provider_type": "inline::llm-as-judge",
      "config": {},
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "scoring",
      "provider_id": "braintrust",
      "provider_type": "inline::braintrust",
      "config": {
        "openai_api_key": "********"
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "tool_runtime",
      "provider_id": "brave-search",
      "provider_type": "remote::brave-search",
      "config": {
        "api_key": "********",
        "max_results": 3
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "tool_runtime",
      "provider_id": "tavily-search",
      "provider_type": "remote::tavily-search",
      "config": {
        "api_key": "********",
        "max_results": 3
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "tool_runtime",
      "provider_id": "code-interpreter",
      "provider_type": "inline::code-interpreter",
      "config": {},
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "tool_runtime",
      "provider_id": "rag-runtime",
      "provider_type": "inline::rag-runtime",
      "config": {},
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "tool_runtime",
      "provider_id": "model-context-protocol",
      "provider_type": "remote::model-context-protocol",
      "config": {},
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    },
    {
      "api": "tool_runtime",
      "provider_id": "wolfram-alpha",
      "provider_type": "remote::wolfram-alpha",
      "config": {
        "api_key": "********"
      },
      "health": {
        "status": "Not Implemented",
        "message": "Provider does not implement health check"
      }
    }
  ]
}
```

The health status is also reported per provider:

```
curl -L http://127.0.0.1:8321/v1/providers/ollama
{"api":"inference","provider_id":"ollama","provider_type":"remote::ollama","config":{"url":"http://localhost:11434"},"health":{"status":"OK"}}
```
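
For scripting against the new field, a minimal client-side sketch (not part of this change; it assumes a server listening on `127.0.0.1:8321` and uses plain `requests`) that flags any provider whose health is not OK:

```python
# Minimal sketch, not part of this change: walk /v1/providers and flag
# anything whose health is not OK. Assumes a server on 127.0.0.1:8321.
import requests

resp = requests.get("http://127.0.0.1:8321/v1/providers", timeout=5)
resp.raise_for_status()

for provider in resp.json()["data"]:
    status = provider["health"]["status"]  # "OK", "Error", or "Not Implemented"
    if status != "OK":
        print(f"{provider['api']}/{provider['provider_id']}: {status}")
```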

Signed-off-by: Sébastien Han <seb@redhat.com>
15 changed files with 244 additions and 76 deletions

```
@@ -99,6 +99,17 @@ jobs:
           cat server.log
           exit 1
+      - name: Verify Ollama status is OK
+        if: matrix.client-type == 'http'
+        run: |
+          echo "Verifying Ollama status..."
+          ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
+          echo "Ollama status: $ollama_status"
+          if [ "$ollama_status" != "OK" ]; then
+            echo "Ollama health check failed"
+            exit 1
+          fi
       - name: Run Integration Tests
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
```

```
@@ -7889,7 +7889,13 @@
         "type": "object",
         "properties": {
           "status": {
-            "type": "string"
+            "type": "string",
+            "enum": [
+              "OK",
+              "Error",
+              "Not Implemented"
+            ],
+            "title": "HealthStatus"
           }
         },
         "additionalProperties": false,
@@ -8084,6 +8090,31 @@
               }
             ]
           }
+        },
+        "health": {
+          "type": "object",
+          "additionalProperties": {
+            "oneOf": [
+              {
+                "type": "null"
+              },
+              {
+                "type": "boolean"
+              },
+              {
+                "type": "number"
+              },
+              {
+                "type": "string"
+              },
+              {
+                "type": "array"
+              },
+              {
+                "type": "object"
+              }
+            ]
+          }
         }
       },
       "additionalProperties": false,
@@ -8091,7 +8122,8 @@
         "api",
         "provider_id",
         "provider_type",
-        "config"
+        "config",
+        "health"
       ],
       "title": "ProviderInfo"
     },
```

```
@@ -5463,6 +5463,11 @@ components:
       properties:
         status:
           type: string
+          enum:
+            - OK
+            - Error
+            - Not Implemented
+          title: HealthStatus
       additionalProperties: false
       required:
         - status
@@ -5574,12 +5579,23 @@ components:
             - type: string
             - type: array
             - type: object
+        health:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
       required:
         - api
         - provider_id
         - provider_type
         - config
+        - health
       title: ProviderInfo
     InvokeToolRequest:
       type: object
```

```
@@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable

 from pydantic import BaseModel

+from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -20,8 +21,7 @@ class RouteInfo(BaseModel):

 @json_schema_type
 class HealthInfo(BaseModel):
-    status: str
-
-    # TODO: add a provider level status
+    status: HealthStatus


 @json_schema_type
```

```
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable

 from pydantic import BaseModel

+from llama_stack.providers.datatypes import HealthResponse
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
     provider_id: str
     provider_type: str
     config: Dict[str, Any]
+    health: HealthResponse


 class ListProvidersResponse(BaseModel):
```

```
@@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
 )
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
+from llama_stack.providers.datatypes import HealthStatus


 class DistributionInspectConfig(BaseModel):
@@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
         return ListRoutesResponse(data=ret)

     async def health(self) -> HealthInfo:
-        return HealthInfo(status="OK")
+        return HealthInfo(status=HealthStatus.OK)

     async def version(self) -> VersionInfo:
         return VersionInfo(version=version("llama-stack"))
```

```
@@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,
-    redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.config import redact_sensitive_fields
 from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
```

```
@@ -4,14 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
+from typing import Any, Dict
+
 from pydantic import BaseModel

 from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
 from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import HealthResponse, HealthStatus

 from .datatypes import StackRunConfig
-from .stack import redact_sensitive_fields
+from .utils.config import redact_sensitive_fields

 logger = get_logger(name=__name__, category="core")
@@ -41,19 +44,24 @@ class ProviderImpl(Providers):
     async def list_providers(self) -> ListProvidersResponse:
         run_config = self.config.run_config
         safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
+        providers_health = await self.get_providers_health()
         ret = []
         for api, providers in safe_config.providers.items():
-            ret.extend(
-                [
-                    ProviderInfo(
-                        api=api,
-                        provider_id=p.provider_id,
-                        provider_type=p.provider_type,
-                        config=p.config,
-                    )
-                    for p in providers
-                ]
-            )
+            for p in providers:
+                ret.append(
+                    ProviderInfo(
+                        api=api,
+                        provider_id=p.provider_id,
+                        provider_type=p.provider_type,
+                        config=p.config,
+                        health=providers_health.get(api, {}).get(
+                            p.provider_id,
+                            HealthResponse(
+                                status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
+                            ),
+                        ),
+                    )
+                )

         return ListProvidersResponse(data=ret)
@@ -64,3 +72,57 @@ class ProviderImpl(Providers):
                 return p

         raise ValueError(f"Provider {provider_id} not found")
+
+    async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
+        """Get health status for all providers.
+
+        Returns:
+            Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
+                Each API maps to a dictionary of provider IDs to their health responses.
+        """
+        providers_health: Dict[str, Dict[str, HealthResponse]] = {}
+        timeout = 1.0
+
+        async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
+            # Skip special implementations (inspect/providers) that don't have provider specs
+            if not hasattr(impl, "__provider_spec__"):
+                return None
+            api_name = impl.__provider_spec__.api.name
+            if not hasattr(impl, "health"):
+                return (
+                    api_name,
+                    HealthResponse(
+                        status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
+                    ),
+                )
+
+            try:
+                health = await asyncio.wait_for(impl.health(), timeout=timeout)
+                return api_name, health
+            except asyncio.TimeoutError:
+                return (
+                    api_name,
+                    HealthResponse(
+                        status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
+                    ),
+                )
+            except Exception as e:
+                return (
+                    api_name,
+                    HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
+                )
+
+        # Create tasks for all providers
+        tasks = [check_provider_health(impl) for impl in self.deps.values()]
+
+        # Wait for all health checks to complete
+        results = await asyncio.gather(*tasks)
+
+        # Organize results by API and provider ID
+        for result in results:
+            if result is None:  # Skip special implementations
+                continue
+            api_name, health_response = result
+            providers_health[api_name] = health_response
+
+        return providers_health
```

```
@@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
     Api,
     BenchmarksProtocolPrivate,
     DatasetsProtocolPrivate,
-    InlineProviderSpec,
     ModelsProtocolPrivate,
     ProviderSpec,
     RemoteProviderConfig,
@@ -230,46 +229,6 @@ def sort_providers_by_deps(
         {k: list(v.values()) for k, v in providers_with_specs.items()}
     )

-    # Append built-in "inspect" provider
-    apis = [x[1].spec.api for x in sorted_providers]
-    sorted_providers.append(
-        (
-            "inspect",
-            ProviderWithSpec(
-                provider_id="__builtin__",
-                provider_type="__builtin__",
-                config={"run_config": run_config.model_dump()},
-                spec=InlineProviderSpec(
-                    api=Api.inspect,
-                    provider_type="__builtin__",
-                    config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
-                    module="llama_stack.distribution.inspect",
-                    api_dependencies=apis,
-                    deps__=[x.value for x in apis],
-                ),
-            ),
-        )
-    )
-    sorted_providers.append(
-        (
-            "providers",
-            ProviderWithSpec(
-                provider_id="__builtin__",
-                provider_type="__builtin__",
-                config={"run_config": run_config.model_dump()},
-                spec=InlineProviderSpec(
-                    api=Api.providers,
-                    provider_type="__builtin__",
-                    config_class="llama_stack.distribution.providers.ProviderImplConfig",
-                    module="llama_stack.distribution.providers",
-                    api_dependencies=apis,
-                    deps__=[x.value for x in apis],
-                ),
-            ),
-        )
-    )
-
     logger.debug(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
         logger.debug(f"  {api_str} => {provider.provider_id}")
```

```
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
@@ -60,7 +61,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

 logger = get_logger(name=__name__, category="core")
@@ -580,6 +581,29 @@ class InferenceRouter(Inference):
             provider = self.routing_table.get_provider_impl(model_obj.identifier)
             return await provider.openai_chat_completion(**params)

+    async def health(self) -> Dict[str, HealthResponse]:
+        health_statuses = {}
+        timeout = 0.5
+        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
+            try:
+                # check if the provider has a health method
+                if not hasattr(impl, "health"):
+                    continue
+                health = await asyncio.wait_for(impl.health(), timeout=timeout)
+                health_statuses[provider_id] = health
+            except asyncio.TimeoutError:
+                health_statuses[provider_id] = HealthResponse(
+                    status=HealthStatus.ERROR,
+                    message=f"Health check timed out after {timeout} seconds",
+                )
+            except NotImplementedError:
+                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
+            except Exception as e:
+                health_statuses[provider_id] = HealthResponse(
+                    status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
+                )
+        return health_statuses
+

 class SafetyRouter(Safety):
     def __init__(
```

```
@@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
 )
 from llama_stack.distribution.stack import (
     construct_stack,
-    redact_sensitive_fields,
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.distribution.utils.config import redact_sensitive_fields
 from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
```

```
@@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.datatypes import Provider, StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
+from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
+from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -119,26 +121,6 @@ class EnvVarError(Exception):
         super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")


-def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
-    """Redact sensitive information from config before printing."""
-    sensitive_patterns = ["api_key", "api_token", "password", "secret"]
-
-    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
-        result = {}
-        for k, v in d.items():
-            if isinstance(v, dict):
-                result[k] = _redact_dict(v)
-            elif isinstance(v, list):
-                result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
-            elif any(pattern in k.lower() for pattern in sensitive_patterns):
-                result[k] = "********"
-            else:
-                result[k] = v
-        return result
-
-    return _redact_dict(data)
-
-
 def replace_env_vars(config: Any, path: str = "") -> Any:
     if isinstance(config, dict):
         result = {}
@@ -215,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
         ) from e


+def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
+    """Add internal implementations (inspect and providers) to the implementations dictionary.
+
+    Args:
+        impls: Dictionary of API implementations
+        run_config: Stack run configuration
+    """
+    inspect_impl = DistributionInspectImpl(
+        DistributionInspectConfig(run_config=run_config),
+        deps=impls,
+    )
+    impls[Api.inspect] = inspect_impl
+
+    providers_impl = ProviderImpl(
+        ProviderImplConfig(run_config=run_config),
+        deps=impls,
+    )
+    impls[Api.providers] = providers_impl
+
+
 # Produces a stack of providers for the given run config. Not all APIs may be
 # asked for in the run config.
 async def construct_stack(
@@ -222,6 +224,10 @@ async def construct_stack(
 ) -> Dict[Api, Any]:
     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
     impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
+
+    # Add internal implementations after all other providers are resolved
+    add_internal_implementations(impls, run_config)
+
     await register_resources(run_config, impls)
     return impls
```

```
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+
+def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
+    """Redact sensitive information from config before printing."""
+    sensitive_patterns = ["api_key", "api_token", "password", "secret"]
+
+    def _redact_value(v: Any) -> Any:
+        if isinstance(v, dict):
+            return _redact_dict(v)
+        elif isinstance(v, list):
+            return [_redact_value(i) for i in v]
+        return v
+
+    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+        result = {}
+        for k, v in d.items():
+            if any(pattern in k.lower() for pattern in sensitive_patterns):
+                result[k] = "********"
+            else:
+                result[k] = _redact_value(v)
+        return result
+
+    return _redact_dict(data)
```

```
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from enum import Enum
 from typing import Any, List, Optional, Protocol
 from urllib.parse import urlparse
@@ -201,3 +202,12 @@ def remote_provider_spec(
         adapter=adapter,
         api_dependencies=api_dependencies or [],
     )
+
+
+class HealthStatus(str, Enum):
+    OK = "OK"
+    ERROR = "Error"
+    NOT_IMPLEMENTED = "Not Implemented"
+
+
+HealthResponse = dict[str, Any]
```
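
Since `HealthResponse` is a plain `dict` alias rather than a model, a provider opts into health reporting simply by exposing an async `health()` method that returns one (the router and providers implementations probe for it with `hasattr(impl, "health")`). A minimal sketch of that contract, using the names introduced above; `MyProvider` and `_probe_backend` are hypothetical stand-ins:

```python
import asyncio

from llama_stack.providers.datatypes import HealthResponse, HealthStatus


class MyProvider:
    async def _probe_backend(self) -> None:
        # Hypothetical stand-in for a real connectivity check,
        # e.g. `await self.client.ps()` in the Ollama adapter below.
        await asyncio.sleep(0)

    async def health(self) -> HealthResponse:
        try:
            await self._probe_backend()
            return HealthResponse(status=HealthStatus.OK)
        except Exception as e:
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {e}")
```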

```
@@ -42,7 +42,11 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import (
+    HealthResponse,
+    HealthStatus,
+    ModelsProtocolPrivate,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -87,8 +91,19 @@ class OllamaInferenceAdapter(
     async def initialize(self) -> None:
         logger.info(f"checking connectivity to Ollama at `{self.url}`...")
+        await self.health()
+
+    async def health(self) -> HealthResponse:
+        """
+        Performs a health check by verifying connectivity to the Ollama server.
+        This method is used by initialize() and the Provider API to verify that the service is running
+        correctly.
+
+        Returns:
+            HealthResponse: A dictionary containing the health status.
+        """
         try:
             await self.client.ps()
+            return HealthResponse(status=HealthStatus.OK)
         except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
```