forked from phoenix-oss/llama-stack-mirror
feat: add health to all providers through providers endpoint (#1418)
The `/v1/providers` now reports the health status of each provider when implemented. ``` curl -L http://127.0.0.1:8321/v1/providers|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 4072 100 4072 0 0 246k 0 --:--:-- --:--:-- --:--:-- 248k { "data": [ { "api": "inference", "provider_id": "ollama", "provider_type": "remote::ollama", "config": { "url": "http://localhost:11434" }, "health": { "status": "OK" } }, { "api": "vector_io", "provider_id": "faiss", "provider_type": "inline::faiss", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/faiss_store.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "safety", "provider_id": "llama-guard", "provider_type": "inline::llama-guard", "config": { "excluded_categories": [] }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "agents", "provider_id": "meta-reference", "provider_type": "inline::meta-reference", "config": { "persistence_store": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/agents_store.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "telemetry", "provider_id": "meta-reference", "provider_type": "inline::meta-reference", "config": { "service_name": "llama-stack", "sinks": "console,sqlite", "sqlite_db_path": "/Users/leseb/.llama/distributions/ollama/trace_store.db" }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "eval", "provider_id": "meta-reference", "provider_type": "inline::meta-reference", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/meta_reference_eval.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "datasetio", "provider_id": "huggingface", "provider_type": "remote::huggingface", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/huggingface_datasetio.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "datasetio", "provider_id": "localfs", "provider_type": "inline::localfs", "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/localfs_datasetio.db" } }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "scoring", "provider_id": "basic", "provider_type": "inline::basic", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "scoring", "provider_id": "llm-as-judge", "provider_type": "inline::llm-as-judge", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "scoring", "provider_id": "braintrust", "provider_type": "inline::braintrust", "config": { "openai_api_key": "********" }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "brave-search", "provider_type": "remote::brave-search", "config": { "api_key": "********", "max_results": 3 }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "tavily-search", "provider_type": "remote::tavily-search", "config": { "api_key": "********", "max_results": 3 }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "code-interpreter", "provider_type": "inline::code-interpreter", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "rag-runtime", "provider_type": "inline::rag-runtime", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "model-context-protocol", "provider_type": "remote::model-context-protocol", "config": {}, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } }, { "api": "tool_runtime", "provider_id": "wolfram-alpha", "provider_type": "remote::wolfram-alpha", "config": { "api_key": "********" }, "health": { "status": "Not Implemented", "message": "Provider does not implement health check" } } ] } ``` Per providers too: ``` curl -L http://127.0.0.1:8321/v1/providers/ollama {"api":"inference","provider_id":"ollama","provider_type":"remote::ollama","config":{"url":"http://localhost:11434"},"health":{"status":"OK"}} ``` Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
ff14773fa7
commit
69554158fa
15 changed files with 244 additions and 76 deletions
11
.github/workflows/integration-tests.yml
vendored
11
.github/workflows/integration-tests.yml
vendored
|
@ -99,6 +99,17 @@ jobs:
|
||||||
cat server.log
|
cat server.log
|
||||||
exit 1
|
exit 1
|
||||||
|
|
||||||
|
- name: Verify Ollama status is OK
|
||||||
|
if: matrix.client-type == 'http'
|
||||||
|
run: |
|
||||||
|
echo "Verifying Ollama status..."
|
||||||
|
ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
|
||||||
|
echo "Ollama status: $ollama_status"
|
||||||
|
if [ "$ollama_status" != "OK" ]; then
|
||||||
|
echo "Ollama health check failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Run Integration Tests
|
- name: Run Integration Tests
|
||||||
env:
|
env:
|
||||||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||||
|
|
36
docs/_static/llama-stack-spec.html
vendored
36
docs/_static/llama-stack-spec.html
vendored
|
@ -7889,7 +7889,13 @@
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"status": {
|
"status": {
|
||||||
"type": "string"
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"OK",
|
||||||
|
"Error",
|
||||||
|
"Not Implemented"
|
||||||
|
],
|
||||||
|
"title": "HealthStatus"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -8084,6 +8090,31 @@
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"health": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -8091,7 +8122,8 @@
|
||||||
"api",
|
"api",
|
||||||
"provider_id",
|
"provider_id",
|
||||||
"provider_type",
|
"provider_type",
|
||||||
"config"
|
"config",
|
||||||
|
"health"
|
||||||
],
|
],
|
||||||
"title": "ProviderInfo"
|
"title": "ProviderInfo"
|
||||||
},
|
},
|
||||||
|
|
16
docs/_static/llama-stack-spec.yaml
vendored
16
docs/_static/llama-stack-spec.yaml
vendored
|
@ -5463,6 +5463,11 @@ components:
|
||||||
properties:
|
properties:
|
||||||
status:
|
status:
|
||||||
type: string
|
type: string
|
||||||
|
enum:
|
||||||
|
- OK
|
||||||
|
- Error
|
||||||
|
- Not Implemented
|
||||||
|
title: HealthStatus
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- status
|
- status
|
||||||
|
@ -5574,12 +5579,23 @@ components:
|
||||||
- type: string
|
- type: string
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
|
health:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- api
|
- api
|
||||||
- provider_id
|
- provider_id
|
||||||
- provider_type
|
- provider_type
|
||||||
- config
|
- config
|
||||||
|
- health
|
||||||
title: ProviderInfo
|
title: ProviderInfo
|
||||||
InvokeToolRequest:
|
InvokeToolRequest:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import HealthStatus
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,8 +21,7 @@ class RouteInfo(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class HealthInfo(BaseModel):
|
class HealthInfo(BaseModel):
|
||||||
status: str
|
status: HealthStatus
|
||||||
# TODO: add a provider level status
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import HealthResponse
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
|
||||||
provider_id: str
|
provider_id: str
|
||||||
provider_type: str
|
provider_type: str
|
||||||
config: Dict[str, Any]
|
config: Dict[str, Any]
|
||||||
|
health: HealthResponse
|
||||||
|
|
||||||
|
|
||||||
class ListProvidersResponse(BaseModel):
|
class ListProvidersResponse(BaseModel):
|
||||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||||
|
from llama_stack.providers.datatypes import HealthStatus
|
||||||
|
|
||||||
|
|
||||||
class DistributionInspectConfig(BaseModel):
|
class DistributionInspectConfig(BaseModel):
|
||||||
|
@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
|
||||||
return ListRoutesResponse(data=ret)
|
return ListRoutesResponse(data=ret)
|
||||||
|
|
||||||
async def health(self) -> HealthInfo:
|
async def health(self) -> HealthInfo:
|
||||||
return HealthInfo(status="OK")
|
return HealthInfo(status=HealthStatus.OK)
|
||||||
|
|
||||||
async def version(self) -> VersionInfo:
|
async def version(self) -> VersionInfo:
|
||||||
return VersionInfo(version=version("llama-stack"))
|
return VersionInfo(version=version("llama-stack"))
|
||||||
|
|
|
@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
construct_stack,
|
construct_stack,
|
||||||
get_stack_run_config_from_template,
|
get_stack_run_config_from_template,
|
||||||
redact_sensitive_fields,
|
|
||||||
replace_env_vars,
|
replace_env_vars,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||||
from llama_stack.distribution.utils.exec import in_notebook
|
from llama_stack.distribution.utils.exec import in_notebook
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
|
|
|
@ -4,14 +4,17 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||||
|
|
||||||
from .datatypes import StackRunConfig
|
from .datatypes import StackRunConfig
|
||||||
from .stack import redact_sensitive_fields
|
from .utils.config import redact_sensitive_fields
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
@ -41,19 +44,24 @@ class ProviderImpl(Providers):
|
||||||
async def list_providers(self) -> ListProvidersResponse:
|
async def list_providers(self) -> ListProvidersResponse:
|
||||||
run_config = self.config.run_config
|
run_config = self.config.run_config
|
||||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||||
|
providers_health = await self.get_providers_health()
|
||||||
ret = []
|
ret = []
|
||||||
for api, providers in safe_config.providers.items():
|
for api, providers in safe_config.providers.items():
|
||||||
ret.extend(
|
for p in providers:
|
||||||
[
|
ret.append(
|
||||||
ProviderInfo(
|
ProviderInfo(
|
||||||
api=api,
|
api=api,
|
||||||
provider_id=p.provider_id,
|
provider_id=p.provider_id,
|
||||||
provider_type=p.provider_type,
|
provider_type=p.provider_type,
|
||||||
config=p.config,
|
config=p.config,
|
||||||
|
health=providers_health.get(api, {}).get(
|
||||||
|
p.provider_id,
|
||||||
|
HealthResponse(
|
||||||
|
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for p in providers
|
)
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return ListProvidersResponse(data=ret)
|
return ListProvidersResponse(data=ret)
|
||||||
|
|
||||||
|
@ -64,3 +72,57 @@ class ProviderImpl(Providers):
|
||||||
return p
|
return p
|
||||||
|
|
||||||
raise ValueError(f"Provider {provider_id} not found")
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
|
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
|
||||||
|
"""Get health status for all providers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
|
||||||
|
Each API maps to a dictionary of provider IDs to their health responses.
|
||||||
|
"""
|
||||||
|
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
|
||||||
|
timeout = 1.0
|
||||||
|
|
||||||
|
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
|
||||||
|
# Skip special implementations (inspect/providers) that don't have provider specs
|
||||||
|
if not hasattr(impl, "__provider_spec__"):
|
||||||
|
return None
|
||||||
|
api_name = impl.__provider_spec__.api.name
|
||||||
|
if not hasattr(impl, "health"):
|
||||||
|
return (
|
||||||
|
api_name,
|
||||||
|
HealthResponse(
|
||||||
|
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||||
|
return api_name, health
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return (
|
||||||
|
api_name,
|
||||||
|
HealthResponse(
|
||||||
|
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return (
|
||||||
|
api_name,
|
||||||
|
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create tasks for all providers
|
||||||
|
tasks = [check_provider_health(impl) for impl in self.deps.values()]
|
||||||
|
|
||||||
|
# Wait for all health checks to complete
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Organize results by API and provider ID
|
||||||
|
for result in results:
|
||||||
|
if result is None: # Skip special implementations
|
||||||
|
continue
|
||||||
|
api_name, health_response = result
|
||||||
|
providers_health[api_name] = health_response
|
||||||
|
|
||||||
|
return providers_health
|
||||||
|
|
|
@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
|
||||||
Api,
|
Api,
|
||||||
BenchmarksProtocolPrivate,
|
BenchmarksProtocolPrivate,
|
||||||
DatasetsProtocolPrivate,
|
DatasetsProtocolPrivate,
|
||||||
InlineProviderSpec,
|
|
||||||
ModelsProtocolPrivate,
|
ModelsProtocolPrivate,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
RemoteProviderConfig,
|
RemoteProviderConfig,
|
||||||
|
@ -230,46 +229,6 @@ def sort_providers_by_deps(
|
||||||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Append built-in "inspect" provider
|
|
||||||
apis = [x[1].spec.api for x in sorted_providers]
|
|
||||||
sorted_providers.append(
|
|
||||||
(
|
|
||||||
"inspect",
|
|
||||||
ProviderWithSpec(
|
|
||||||
provider_id="__builtin__",
|
|
||||||
provider_type="__builtin__",
|
|
||||||
config={"run_config": run_config.model_dump()},
|
|
||||||
spec=InlineProviderSpec(
|
|
||||||
api=Api.inspect,
|
|
||||||
provider_type="__builtin__",
|
|
||||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
|
||||||
module="llama_stack.distribution.inspect",
|
|
||||||
api_dependencies=apis,
|
|
||||||
deps__=[x.value for x in apis],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
sorted_providers.append(
|
|
||||||
(
|
|
||||||
"providers",
|
|
||||||
ProviderWithSpec(
|
|
||||||
provider_id="__builtin__",
|
|
||||||
provider_type="__builtin__",
|
|
||||||
config={"run_config": run_config.model_dump()},
|
|
||||||
spec=InlineProviderSpec(
|
|
||||||
api=Api.providers,
|
|
||||||
provider_type="__builtin__",
|
|
||||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
|
||||||
module="llama_stack.distribution.providers",
|
|
||||||
api_dependencies=apis,
|
|
||||||
deps__=[x.value for x in apis],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
@ -60,7 +61,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
@ -580,6 +581,29 @@ class InferenceRouter(Inference):
|
||||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||||
return await provider.openai_chat_completion(**params)
|
return await provider.openai_chat_completion(**params)
|
||||||
|
|
||||||
|
async def health(self) -> Dict[str, HealthResponse]:
|
||||||
|
health_statuses = {}
|
||||||
|
timeout = 0.5
|
||||||
|
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||||
|
try:
|
||||||
|
# check if the provider has a health method
|
||||||
|
if not hasattr(impl, "health"):
|
||||||
|
continue
|
||||||
|
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||||
|
health_statuses[provider_id] = health
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
health_statuses[provider_id] = HealthResponse(
|
||||||
|
status=HealthStatus.ERROR,
|
||||||
|
message=f"Health check timed out after {timeout} seconds",
|
||||||
|
)
|
||||||
|
except NotImplementedError:
|
||||||
|
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||||
|
except Exception as e:
|
||||||
|
health_statuses[provider_id] = HealthResponse(
|
||||||
|
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||||
|
)
|
||||||
|
return health_statuses
|
||||||
|
|
||||||
|
|
||||||
class SafetyRouter(Safety):
|
class SafetyRouter(Safety):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
construct_stack,
|
construct_stack,
|
||||||
redact_sensitive_fields,
|
|
||||||
replace_env_vars,
|
replace_env_vars,
|
||||||
validate_env_pair,
|
validate_env_pair,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
|
@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||||
|
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||||
from llama_stack.distribution.store.registry import create_dist_registry
|
from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
@ -119,26 +121,6 @@ class EnvVarError(Exception):
|
||||||
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
||||||
|
|
||||||
|
|
||||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Redact sensitive information from config before printing."""
|
|
||||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
|
||||||
|
|
||||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
result = {}
|
|
||||||
for k, v in d.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
result[k] = _redact_dict(v)
|
|
||||||
elif isinstance(v, list):
|
|
||||||
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
|
|
||||||
elif any(pattern in k.lower() for pattern in sensitive_patterns):
|
|
||||||
result[k] = "********"
|
|
||||||
else:
|
|
||||||
result[k] = v
|
|
||||||
return result
|
|
||||||
|
|
||||||
return _redact_dict(data)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||||
if isinstance(config, dict):
|
if isinstance(config, dict):
|
||||||
result = {}
|
result = {}
|
||||||
|
@ -215,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||||
|
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
impls: Dictionary of API implementations
|
||||||
|
run_config: Stack run configuration
|
||||||
|
"""
|
||||||
|
inspect_impl = DistributionInspectImpl(
|
||||||
|
DistributionInspectConfig(run_config=run_config),
|
||||||
|
deps=impls,
|
||||||
|
)
|
||||||
|
impls[Api.inspect] = inspect_impl
|
||||||
|
|
||||||
|
providers_impl = ProviderImpl(
|
||||||
|
ProviderImplConfig(run_config=run_config),
|
||||||
|
deps=impls,
|
||||||
|
)
|
||||||
|
impls[Api.providers] = providers_impl
|
||||||
|
|
||||||
|
|
||||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||||
# asked for in the run config.
|
# asked for in the run config.
|
||||||
async def construct_stack(
|
async def construct_stack(
|
||||||
|
@ -222,6 +224,10 @@ async def construct_stack(
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||||
|
|
||||||
|
# Add internal implementations after all other providers are resolved
|
||||||
|
add_internal_implementations(impls, run_config)
|
||||||
|
|
||||||
await register_resources(run_config, impls)
|
await register_resources(run_config, impls)
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
30
llama_stack/distribution/utils/config.py
Normal file
30
llama_stack/distribution/utils/config.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Redact sensitive information from config before printing."""
|
||||||
|
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||||
|
|
||||||
|
def _redact_value(v: Any) -> Any:
|
||||||
|
if isinstance(v, dict):
|
||||||
|
return _redact_dict(v)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
return [_redact_value(i) for i in v]
|
||||||
|
return v
|
||||||
|
|
||||||
|
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
for k, v in d.items():
|
||||||
|
if any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||||
|
result[k] = "********"
|
||||||
|
else:
|
||||||
|
result[k] = _redact_value(v)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return _redact_dict(data)
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, List, Optional, Protocol
|
from typing import Any, List, Optional, Protocol
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
@ -201,3 +202,12 @@ def remote_provider_spec(
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
api_dependencies=api_dependencies or [],
|
api_dependencies=api_dependencies or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthStatus(str, Enum):
|
||||||
|
OK = "OK"
|
||||||
|
ERROR = "Error"
|
||||||
|
NOT_IMPLEMENTED = "Not Implemented"
|
||||||
|
|
||||||
|
|
||||||
|
HealthResponse = dict[str, Any]
|
||||||
|
|
|
@ -42,7 +42,11 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import (
|
||||||
|
HealthResponse,
|
||||||
|
HealthStatus,
|
||||||
|
ModelsProtocolPrivate,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
)
|
)
|
||||||
|
@ -87,8 +91,19 @@ class OllamaInferenceAdapter(
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||||
|
await self.health()
|
||||||
|
|
||||||
|
async def health(self) -> HealthResponse:
|
||||||
|
"""
|
||||||
|
Performs a health check by verifying connectivity to the Ollama server.
|
||||||
|
This method is used by initialize() and the Provider API to verify that the service is running
|
||||||
|
correctly.
|
||||||
|
Returns:
|
||||||
|
HealthResponse: A dictionary containing the health status.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
await self.client.ps()
|
await self.client.ps()
|
||||||
|
return HealthResponse(status=HealthStatus.OK)
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue