forked from phoenix-oss/llama-stack-mirror
feat: add health to all providers through providers endpoint (#1418)
The `/v1/providers` now reports the health status of each provider when implemented. ``` curl -L http://127.0.0.1:8321/v1/providers|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 4072 100 4072 0 0 246k 0 --:--:-- --:--:-- --:--:-- 248k { "data": [ { "api": "inference", "provider_id": "ollama", "provider_type": "remote::ollama", "config": { "url": "http://localhost:11434" }, "health": { "status": "OK" } }, { "api": "vector_io", "provider_id": "faiss", "provider_type": "inline::faiss", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/faiss_store.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "safety", "provider_id": "llama-guard", "provider_type": "inline::llama-guard", "config": { "excluded_categories": [] }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "agents", "provider_id": "meta-reference", "provider_type": "inline::meta-reference", "config": { "persistence_store": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/agents_store.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "telemetry", "provider_id": "meta-reference", "provider_type": "inline::meta-reference", "config": { "service_name": "llama-stack", "sinks": "console,sqlite", "sqlite_db_path": "/Users/leseb/.llama/distributions/ollama/trace_store.db" }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "eval", "provider_id": "meta-reference", "provider_type": "inline::meta-reference", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/meta_reference_eval.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "datasetio", "provider_id": "huggingface", "provider_type": "remote::huggingface", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/huggingface_datasetio.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "datasetio", "provider_id": "localfs", "provider_type": "inline::localfs", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/localfs_datasetio.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "scoring", "provider_id": "basic", "provider_type": "inline::basic", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "scoring", "provider_id": "llm-as-judge", "provider_type": "inline::llm-as-judge", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "scoring", "provider_id": "braintrust", "provider_type": "inline::braintrust", "config": { "openai_api_key": "********" }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "brave-search", "provider_type": "remote::brave-search", "config": { "api_key": "********", "max_results": 3 }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "tavily-search", "provider_type": "remote::tavily-search", "config": { "api_key": "********", "max_results": 3 }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "code-interpreter", "provider_type": "inline::code-interpreter", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "rag-runtime", "provider_type": "inline::rag-runtime", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "model-context-protocol", "provider_type": "remote::model-context-protocol", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "wolfram-alpha", "provider_type": "remote::wolfram-alpha", "config": { "api_key": "********" }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } } ] } ``` Per providers too: ``` curl -L http://127.0.0.1:8321/v1/providers/ollama {"api":"inference","provider_id":"ollama","provider_type":"remote::ollama","config":{"url":"http://localhost:11434"},"health":{"status":"OK"}} ``` Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
ff14773fa7
commit
69554158fa
15 changed files with 244 additions and 76 deletions
11
.github/workflows/integration-tests.yml
vendored
11
.github/workflows/integration-tests.yml
vendored
|
@ -99,6 +99,17 @@ jobs:
|
|||
cat server.log
|
||||
exit 1
|
||||
|
||||
- name: Verify Ollama status is OK
|
||||
if: matrix.client-type == 'http'
|
||||
run: |
|
||||
echo "Verifying Ollama status..."
|
||||
ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
|
||||
echo "Ollama status: $ollama_status"
|
||||
if [ "$ollama_status" != "OK" ]; then
|
||||
echo "Ollama health check failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Run Integration Tests
|
||||
env:
|
||||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
|
|
36
docs/_static/llama-stack-spec.html
vendored
36
docs/_static/llama-stack-spec.html
vendored
|
@ -7889,7 +7889,13 @@
|
|||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string"
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"OK",
|
||||
"Error",
|
||||
"Not Implemented"
|
||||
],
|
||||
"title": "HealthStatus"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -8084,6 +8090,31 @@
|
|||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"health": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -8091,7 +8122,8 @@
|
|||
"api",
|
||||
"provider_id",
|
||||
"provider_type",
|
||||
"config"
|
||||
"config",
|
||||
"health"
|
||||
],
|
||||
"title": "ProviderInfo"
|
||||
},
|
||||
|
|
16
docs/_static/llama-stack-spec.yaml
vendored
16
docs/_static/llama-stack-spec.yaml
vendored
|
@ -5463,6 +5463,11 @@ components:
|
|||
properties:
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- OK
|
||||
- Error
|
||||
- Not Implemented
|
||||
title: HealthStatus
|
||||
additionalProperties: false
|
||||
required:
|
||||
- status
|
||||
|
@ -5574,12 +5579,23 @@ components:
|
|||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
health:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- api
|
||||
- provider_id
|
||||
- provider_type
|
||||
- config
|
||||
- health
|
||||
title: ProviderInfo
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
@ -20,8 +21,7 @@ class RouteInfo(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class HealthInfo(BaseModel):
|
||||
status: str
|
||||
# TODO: add a provider level status
|
||||
status: HealthStatus
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import HealthResponse
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
|
|||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
health: HealthResponse
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
|
|||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
|
@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
|
|||
return ListRoutesResponse(data=ret)
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
return HealthInfo(status="OK")
|
||||
return HealthInfo(status=HealthStatus.OK)
|
||||
|
||||
async def version(self) -> VersionInfo:
|
||||
return VersionInfo(version=version("llama-stack"))
|
||||
|
|
|
@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
|
|||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
|
|
|
@ -4,14 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -41,19 +44,24 @@ class ProviderImpl(Providers):
|
|||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
providers_health = await self.get_providers_health()
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
for p in providers:
|
||||
ret.append(
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
health=providers_health.get(api, {}).get(
|
||||
p.provider_id,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
),
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
|
@ -64,3 +72,57 @@ class ProviderImpl(Providers):
|
|||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
|
||||
"""Get health status for all providers.
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
|
||||
Each API maps to a dictionary of provider IDs to their health responses.
|
||||
"""
|
||||
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
|
||||
timeout = 1.0
|
||||
|
||||
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
|
||||
# Skip special implementations (inspect/providers) that don't have provider specs
|
||||
if not hasattr(impl, "__provider_spec__"):
|
||||
return None
|
||||
api_name = impl.__provider_spec__.api.name
|
||||
if not hasattr(impl, "health"):
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
return api_name, health
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
|
||||
)
|
||||
|
||||
# Create tasks for all providers
|
||||
tasks = [check_provider_health(impl) for impl in self.deps.values()]
|
||||
|
||||
# Wait for all health checks to complete
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Organize results by API and provider ID
|
||||
for result in results:
|
||||
if result is None: # Skip special implementations
|
||||
continue
|
||||
api_name, health_response = result
|
||||
providers_health[api_name] = health_response
|
||||
|
||||
return providers_health
|
||||
|
|
|
@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
|
|||
Api,
|
||||
BenchmarksProtocolPrivate,
|
||||
DatasetsProtocolPrivate,
|
||||
InlineProviderSpec,
|
||||
ModelsProtocolPrivate,
|
||||
ProviderSpec,
|
||||
RemoteProviderConfig,
|
||||
|
@ -230,46 +229,6 @@ def sort_providers_by_deps(
|
|||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||
)
|
||||
|
||||
# Append built-in "inspect" provider
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
sorted_providers.append(
|
||||
(
|
||||
"providers",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.providers,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
||||
module="llama_stack.distribution.providers",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
|
@ -60,7 +61,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -580,6 +581,29 @@ class InferenceRouter(Inference):
|
|||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_chat_completion(**params)
|
||||
|
||||
async def health(self) -> Dict[str, HealthResponse]:
|
||||
health_statuses = {}
|
||||
timeout = 0.5
|
||||
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||
try:
|
||||
# check if the provider has a health method
|
||||
if not hasattr(impl, "health"):
|
||||
continue
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
health_statuses[provider_id] = health
|
||||
except asyncio.TimeoutError:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"Health check timed out after {timeout} seconds",
|
||||
)
|
||||
except NotImplementedError:
|
||||
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||
except Exception as e:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
|
|
|
@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
|
|||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
|
@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
@ -119,26 +121,6 @@ class EnvVarError(Exception):
|
|||
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
|
||||
elif any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
|
@ -215,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
Args:
|
||||
impls: Dictionary of API implementations
|
||||
run_config: Stack run configuration
|
||||
"""
|
||||
inspect_impl = DistributionInspectImpl(
|
||||
DistributionInspectConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.inspect] = inspect_impl
|
||||
|
||||
providers_impl = ProviderImpl(
|
||||
ProviderImplConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.providers] = providers_impl
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
|
@ -222,6 +224,10 @@ async def construct_stack(
|
|||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
||||
|
|
30
llama_stack/distribution/utils/config.py
Normal file
30
llama_stack/distribution/utils/config.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_value(v: Any) -> Any:
|
||||
if isinstance(v, dict):
|
||||
return _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
return [_redact_value(i) for i in v]
|
||||
return v
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = _redact_value(v)
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
@ -201,3 +202,12 @@ def remote_provider_spec(
|
|||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
OK = "OK"
|
||||
ERROR = "Error"
|
||||
NOT_IMPLEMENTED = "Not Implemented"
|
||||
|
||||
|
||||
HealthResponse = dict[str, Any]
|
||||
|
|
|
@ -42,7 +42,11 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -87,8 +91,19 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
await self.health()
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
This method is used by initialize() and the Provider API to verify that the service is running
|
||||
correctly.
|
||||
Returns:
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except httpx.ConnectError as e:
|
||||
raise RuntimeError(
|
||||
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue