fix: rename StackRunConfig to StackConfig
Since this object represents our config for list-deps, run, etc., let's rename it to simply `StackConfig`.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
parent 17f8ab31b5
commit 4a3f9151e3

23 changed files with 72 additions and 72 deletions
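Most call sites change only the class name. Below is a minimal sketch of the post-rename usage, assuming this commit is applied; the run.yaml path is illustrative, and the pattern (env-var substitution, image_name normalization, validation) mirrors the CLI and server hunks further down.

import yaml

from llama_stack.core.datatypes import StackConfig
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars

# Illustrative path; any Llama Stack run config YAML works here.
with open("run.yaml") as f:
    contents = yaml.safe_load(f)

# Substitute environment variable references, normalize image_name to a string, then validate.
config = StackConfig(**cast_image_name_to_string(replace_env_vars(contents)))
print(config.image_name)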
@@ -12,7 +12,7 @@ import yaml
 from termcolor import cprint
 
 from llama_stack.core.build import get_provider_dependencies
-from llama_stack.core.datatypes import Provider, StackRunConfig
+from llama_stack.core.datatypes import Provider, StackConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.log import get_logger
 from llama_stack_api import Api
@@ -78,7 +78,7 @@ def run_stack_list_deps_command(args: argparse.Namespace) -> None:
         with open(config_file) as f:
             try:
                 contents = yaml.safe_load(f)
-                run_config = StackRunConfig(**contents)
+                run_config = StackConfig(**contents)
             except Exception as e:
                 cprint(
                     f"Could not parse config file {config_file}: {e}",
@@ -119,7 +119,7 @@ def run_stack_list_deps_command(args: argparse.Namespace) -> None:
                 file=sys.stderr,
             )
             sys.exit(1)
-        run_config = StackRunConfig(providers=provider_list, image_name="providers-run")
+        run_config = StackConfig(providers=provider_list, image_name="providers-run")
 
     normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(run_config)
     normal_deps += SERVER_DEPENDENCIES

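The list-deps flow above can also build its config directly from a provider mapping rather than a YAML file. A hedged sketch of that path; the single ollama entry is an assumption used for illustration, and the dependency lookup only succeeds if that provider type exists in the local provider registry.

from llama_stack.core.build import get_provider_dependencies
from llama_stack.core.datatypes import Provider, StackConfig

# Assumed provider mapping (API name -> providers); adjust to providers you actually use.
provider_list = {
    "inference": [
        Provider(provider_id="ollama", provider_type="remote::ollama", config={}),
    ],
}

run_config = StackConfig(providers=provider_list, image_name="providers-run")
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(run_config)
print(sorted(normal_deps))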
@@ -17,7 +17,7 @@ from termcolor import cprint
 
 from llama_stack.cli.stack.utils import ImageType
 from llama_stack.cli.subcommand import Subcommand
-from llama_stack.core.datatypes import Api, Provider, StackRunConfig
+from llama_stack.core.datatypes import Api, Provider, StackConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
 from llama_stack.core.storage.datatypes import (
@@ -156,7 +156,7 @@ class StackRun(Subcommand):
 
             # Write config to disk in providers-run directory
             distro_dir = DISTRIBS_BASE_DIR / "providers-run"
-            config_file = distro_dir / "run.yaml"
+            config_file = distro_dir / "config.yaml"
 
             logger.info(f"Writing generated config to: {config_file}")
             with open(config_file, "w") as f:
@@ -194,7 +194,7 @@ class StackRun(Subcommand):
                 logger_config = LoggingConfig(**cfg)
             else:
                 logger_config = None
-            config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+            config = StackConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
 
             port = args.port or config.server.port
             host = config.server.host or "0.0.0.0"
@@ -318,7 +318,7 @@ class StackRun(Subcommand):
             ),
         )
 
-        return StackRunConfig(
+        return StackConfig(
            image_name="providers-run",
            apis=apis,
            providers=providers,

@@ -16,7 +16,7 @@ from termcolor import cprint
 from llama_stack.core.datatypes import (
     BuildConfig,
     Provider,
-    StackRunConfig,
+    StackConfig,
     StorageConfig,
 )
 from llama_stack.core.distribution import get_provider_registry
@@ -61,7 +61,7 @@ def generate_run_config(
     """
     apis = list(build_config.distribution_spec.providers.keys())
     distro_dir = DISTRIBS_BASE_DIR / image_name
-    run_config = StackRunConfig(
+    run_config = StackConfig(
         container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
         image_name=image_name,
         apis=apis,

@@ -9,7 +9,7 @@ import sys
 from pydantic import BaseModel
 from termcolor import cprint
 
-from llama_stack.core.datatypes import StackRunConfig
+from llama_stack.core.datatypes import StackConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.distributions.template import DistributionTemplate
 from llama_stack.log import get_logger
@@ -36,7 +36,7 @@ class ApiInput(BaseModel):
 
 
 def get_provider_dependencies(
-    config: StackRunConfig,
+    config: StackConfig,
 ) -> tuple[list[str], list[str], list[str]]:
     """Get normal and special dependencies from provider configuration."""
     if isinstance(config, DistributionTemplate):
@@ -83,7 +83,7 @@ def get_provider_dependencies(
     return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
 
 
-def print_pip_install_help(config: StackRunConfig):
+def print_pip_install_help(config: StackConfig):
     normal_deps, special_deps, _ = get_provider_dependencies(config)
 
     cprint(

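Apart from the annotation, the helpers above keep their behavior. A small hedged sketch of calling them with an intentionally empty config, which yields no provider-specific dependencies.

from llama_stack.core.build import get_provider_dependencies, print_pip_install_help
from llama_stack.core.datatypes import StackConfig

# Empty providers: only the base return values are exercised here.
config = StackConfig(image_name="providers-run", providers={})

normal_deps, special_deps, external_provider_deps = get_provider_dependencies(config)
print_pip_install_help(config)  # prints an installation hint for the collected dependencies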
@@ -10,7 +10,7 @@ from llama_stack.core.datatypes import (
     LLAMA_STACK_RUN_CONFIG_VERSION,
     DistributionSpec,
     Provider,
-    StackRunConfig,
+    StackConfig,
 )
 from llama_stack.core.distribution import (
     builtin_automatically_routed_apis,
@@ -44,7 +44,7 @@ def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provi
     )
 
 
-def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
+def configure_api_providers(config: StackConfig, build_spec: DistributionSpec) -> StackConfig:
     is_nux = len(config.providers) == 0
 
     if is_nux:
@@ -192,7 +192,7 @@ def upgrade_from_routing_table(
     return config_dict
 
 
-def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
+def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackConfig:
     if "routing_table" in config_dict:
         logger.info("Upgrading config...")
         config_dict = upgrade_from_routing_table(config_dict)
@@ -200,4 +200,4 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
     config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
 
     processed_config_dict = replace_env_vars(config_dict)
-    return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
+    return StackConfig(**cast_image_name_to_string(processed_config_dict))

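The upgrade path keeps working the same way; only the returned type name changes. A hedged sketch of driving it from a loaded YAML dict; the module path is assumed (file names are not shown in this view) and the config path is illustrative.

import yaml

# Module path assumed; adjust if the function lives elsewhere in your checkout.
from llama_stack.core.configure import parse_and_maybe_upgrade_config

with open("run.yaml") as f:  # illustrative path
    config_dict = yaml.safe_load(f)

# Upgrades a legacy "routing_table" layout if present, then validates into a StackConfig.
config = parse_and_maybe_upgrade_config(config_dict)
print(config.version)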
@@ -10,7 +10,7 @@ from typing import Any, Literal
 
 from pydantic import BaseModel, TypeAdapter
 
-from llama_stack.core.datatypes import AccessRule, StackRunConfig
+from llama_stack.core.datatypes import AccessRule, StackConfig
 from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
 from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack.log import get_logger
@@ -36,7 +36,7 @@ class ConversationServiceConfig(BaseModel):
     :param policy: Access control rules
     """
 
-    run_config: StackRunConfig
+    run_config: StackConfig
     policy: list[AccessRule] = []
 

@@ -490,7 +490,7 @@ class ServerConfig(BaseModel):
     )
 
 
-class StackRunConfig(BaseModel):
+class StackConfig(BaseModel):
     version: int = LLAMA_STACK_RUN_CONFIG_VERSION
 
     image_name: str = Field(
@@ -565,7 +565,7 @@ can be instantiated multiple times (with different configs) if necessary.
         return v
 
     @model_validator(mode="after")
-    def validate_server_stores(self) -> "StackRunConfig":
+    def validate_server_stores(self) -> "StackConfig":
         backend_map = self.storage.backends
         stores = self.storage.stores
         kv_backends = {

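Only the model's name changes here; the schema it validates is untouched. A minimal sketch, reusing the starter distribution config that the tests later in this diff also load; run it from a llama-stack checkout so the relative path resolves.

import yaml

from llama_stack.core.datatypes import StackConfig

with open("llama_stack/distributions/starter/run.yaml") as f:
    config_data = yaml.safe_load(f)

config = StackConfig.model_validate(config_data)
print(config.version, config.image_name)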
@@ -7,14 +7,14 @@
 
 import yaml
 
-from llama_stack.core.datatypes import BuildConfig, StackRunConfig
+from llama_stack.core.datatypes import BuildConfig, StackConfig
 from llama_stack.log import get_logger
 from llama_stack_api import Api, ExternalApiSpec
 
 logger = get_logger(name=__name__, category="core")
 
 
-def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
+def load_external_apis(config: StackConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
     """Load external API specifications from the configured directory.
 
     Args:

@@ -8,7 +8,7 @@ from importlib.metadata import version
 
 from pydantic import BaseModel
 
-from llama_stack.core.datatypes import StackRunConfig
+from llama_stack.core.datatypes import StackConfig
 from llama_stack.core.external import load_external_apis
 from llama_stack.core.server.routes import get_all_api_routes
 from llama_stack_api import (
@@ -22,7 +22,7 @@ from llama_stack_api import (
 
 
 class DistributionInspectConfig(BaseModel):
-    run_config: StackRunConfig
+    run_config: StackConfig
 
 
 async def get_provider_impl(config, deps):
@@ -40,7 +40,7 @@ class DistributionInspectImpl(Inspect):
         pass
 
     async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
-        run_config: StackRunConfig = self.config.run_config
+        run_config: StackConfig = self.config.run_config
 
         # Helper function to determine if a route should be included based on api_filter
         def should_include_route(webmethod) -> bool:

@@ -9,7 +9,7 @@ from typing import Any
 
 from pydantic import BaseModel
 
-from llama_stack.core.datatypes import StackRunConfig
+from llama_stack.core.datatypes import StackConfig
 from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
 from llama_stack_api import ListPromptsResponse, Prompt, Prompts
 
@@ -20,7 +20,7 @@ class PromptServiceConfig(BaseModel):
     :param run_config: Stack run configuration containing distribution info
     """
 
-    run_config: StackRunConfig
+    run_config: StackConfig
 
 
 async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):

@@ -12,14 +12,14 @@ from pydantic import BaseModel
 from llama_stack.log import get_logger
 from llama_stack_api import HealthResponse, HealthStatus, ListProvidersResponse, ProviderInfo, Providers
 
-from .datatypes import StackRunConfig
+from .datatypes import StackConfig
 from .utils.config import redact_sensitive_fields
 
 logger = get_logger(name=__name__, category="core")
 
 
 class ProviderImplConfig(BaseModel):
-    run_config: StackRunConfig
+    run_config: StackConfig
 
 
 async def get_provider_impl(config, deps):
@@ -42,7 +42,7 @@ class ProviderImpl(Providers):
 
     async def list_providers(self) -> ListProvidersResponse:
         run_config = self.config.run_config
-        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
+        safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump()))
         providers_health = await self.get_providers_health()
         ret = []
         for api, providers in safe_config.providers.items():

@@ -14,7 +14,7 @@ from llama_stack.core.datatypes import (
     AutoRoutedProviderSpec,
     Provider,
     RoutingTableProviderSpec,
-    StackRunConfig,
+    StackConfig,
 )
 from llama_stack.core.distribution import builtin_automatically_routed_apis
 from llama_stack.core.external import load_external_apis
@@ -147,7 +147,7 @@ ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
 
 
 async def resolve_impls(
-    run_config: StackRunConfig,
+    run_config: StackConfig,
     provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
     policy: list[AccessRule],
@@ -217,7 +217,7 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
 
 
 def validate_and_prepare_providers(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
+    run_config: StackConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
 ) -> dict[str, dict[str, ProviderWithSpec]]:
     """Validates providers, handles deprecations, and organizes them into a spec dictionary."""
     providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
@@ -261,7 +261,7 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
 
 
 def sort_providers_by_deps(
-    providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
+    providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackConfig
 ) -> list[tuple[str, ProviderWithSpec]]:
     """Sorts providers based on their dependencies."""
     sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
@@ -278,7 +278,7 @@ async def instantiate_providers(
     sorted_providers: list[tuple[str, ProviderWithSpec]],
     router_apis: set[Api],
     dist_registry: DistributionRegistry,
-    run_config: StackRunConfig,
+    run_config: StackConfig,
     policy: list[AccessRule],
     internal_impls: dict[Api, Any] | None = None,
 ) -> dict[Api, Any]:
@@ -357,7 +357,7 @@ async def instantiate_provider(
     deps: dict[Api, Any],
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
-    run_config: StackRunConfig,
+    run_config: StackConfig,
     policy: list[AccessRule],
 ):
     provider_spec = provider.spec

@@ -10,7 +10,7 @@ from llama_stack.core.datatypes import (
     AccessRule,
     RoutedProtocol,
 )
-from llama_stack.core.stack import StackRunConfig
+from llama_stack.core.datatypes import StackConfig
 from llama_stack.core.store import DistributionRegistry
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack_api import Api, RoutingTable
@@ -51,7 +51,7 @@ async def get_routing_table_impl(
 
 
 async def get_auto_router_impl(
-    api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
+    api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackConfig, policy: list[AccessRule]
 ) -> Any:
     from .datasets import DatasetIORouter
     from .eval_scoring import EvalRouter, ScoringRouter

@@ -34,7 +34,7 @@ from pydantic import BaseModel, ValidationError
 from llama_stack.core.access_control.access_control import AccessDeniedError
 from llama_stack.core.datatypes import (
     AuthenticationRequiredError,
-    StackRunConfig,
+    StackConfig,
     process_cors_config,
 )
 from llama_stack.core.distribution import builtin_automatically_routed_apis
@@ -149,7 +149,7 @@ class StackApp(FastAPI):
     start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
     """
 
-    def __init__(self, config: StackRunConfig, *args, **kwargs):
+    def __init__(self, config: StackConfig, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.stack: Stack = Stack(config)
 
@@ -385,7 +385,7 @@ def create_app() -> StackApp:
     logger = get_logger(name=__name__, category="core::server", config=logger_config)
 
     config = replace_env_vars(config_contents)
-    config = StackRunConfig(**cast_image_name_to_string(config))
+    config = StackConfig(**cast_image_name_to_string(config))
 
     _log_run_config(run_config=config)
 
@@ -506,7 +506,7 @@ def create_app() -> StackApp:
     return app
 
 
-def _log_run_config(run_config: StackRunConfig):
+def _log_run_config(run_config: StackConfig):
     """Logs the run config with redacted fields and disabled providers removed."""
     logger.info("Run configuration:")
     safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))

@@ -14,7 +14,7 @@ from typing import Any
 import yaml
 
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
-from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
+from llama_stack.core.datatypes import Provider, SafetyConfig, StackConfig, VectorStoresConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -108,7 +108,7 @@ REGISTRY_REFRESH_TASK = None
 TEST_RECORDING_CONTEXT = None
 
 
-async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
+async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
     for rsrc, api, register_method, list_method in RESOURCES:
         objects = getattr(run_config.registered_resources, rsrc)
         if api not in impls:
@@ -341,7 +341,7 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
     return config_dict
 
 
-def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
+def add_internal_implementations(impls: dict[Api, Any], run_config: StackConfig) -> None:
     """Add internal implementations (inspect and providers) to the implementations dictionary.
 
     Args:
@@ -373,7 +373,7 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
     impls[Api.conversations] = conversations_impl
 
 
-def _initialize_storage(run_config: StackRunConfig):
+def _initialize_storage(run_config: StackConfig):
     kv_backends: dict[str, StorageBackendConfig] = {}
     sql_backends: dict[str, StorageBackendConfig] = {}
     for backend_name, backend_config in run_config.storage.backends.items():
@@ -393,7 +393,7 @@ def _initialize_storage(run_config: StackRunConfig):
 
 
 class Stack:
-    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+    def __init__(self, run_config: StackConfig, provider_registry: ProviderRegistry | None = None):
         self.run_config = run_config
         self.provider_registry = provider_registry
         self.impls = None
@@ -499,7 +499,7 @@ async def refresh_registry_task(impls: dict[Api, Any]):
         await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
 
 
-def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
+def get_stack_run_config_from_distro(distro: str) -> StackConfig:
     distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml"
 
     with importlib.resources.as_file(distro_path) as path:
@@ -507,12 +507,12 @@ def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
             raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
         run_config = yaml.safe_load(path.open())
 
-    return StackRunConfig(**replace_env_vars(run_config))
+    return StackConfig(**replace_env_vars(run_config))
 
 
 def run_config_from_adhoc_config_spec(
     adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
-) -> StackRunConfig:
+) -> StackConfig:
     """
     Create an adhoc distribution from a list of API providers.
 
@@ -552,7 +552,7 @@ def run_config_from_adhoc_config_spec(
                 config=provider_config,
             )
         ]
-        config = StackRunConfig(
+        config = StackConfig(
             image_name="distro-test",
             apis=list(provider_configs_by_api.keys()),
             providers=provider_configs_by_api,

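Most of the call sites above live in llama_stack.core.stack, including the distro helper renamed in the last hunks. A hedged sketch of using it; "starter" is one of the distributions referenced elsewhere in this diff.

from llama_stack.core.stack import get_stack_run_config_from_distro

# Loads distributions/starter/run.yaml from the installed package and
# validates it into the renamed StackConfig model.
config = get_stack_run_config_from_distro("starter")
print(config.image_name, sorted(config.providers))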
@@ -17,7 +17,7 @@ from pathlib import Path
 import pytest
 import yaml
 
-from llama_stack.core.datatypes import StackRunConfig
+from llama_stack.core.datatypes import StackConfig
 
 
 def get_test_configs():
@@ -49,4 +49,4 @@ def test_load_run_config(config_file):
     with open(config_file) as f:
         config_data = yaml.safe_load(f)
 
-    StackRunConfig.model_validate(config_data)
+    StackConfig.model_validate(config_data)

@@ -6,7 +6,7 @@
 
 import yaml
 
-from llama_stack.core.datatypes import StackRunConfig
+from llama_stack.core.datatypes import StackConfig
 from llama_stack.core.storage.datatypes import (
     PostgresKVStoreConfig,
     PostgresSqlStoreConfig,
@@ -20,7 +20,7 @@ def test_starter_distribution_config_loads_and_resolves():
     with open("llama_stack/distributions/starter/run.yaml") as f:
         config_dict = yaml.safe_load(f)
 
-    config = StackRunConfig(**config_dict)
+    config = StackConfig(**config_dict)
 
     # Config should have named backends and explicit store references
     assert config.storage is not None
@@ -50,7 +50,7 @@ def test_postgres_demo_distribution_config_loads():
     with open("llama_stack/distributions/postgres-demo/run.yaml") as f:
         config_dict = yaml.safe_load(f)
 
-    config = StackRunConfig(**config_dict)
+    config = StackConfig(**config_dict)
 
     # Should have postgres backend
     assert config.storage is not None

@@ -16,7 +16,7 @@ from llama_stack.core.conversations.conversations import (
     ConversationServiceConfig,
     ConversationServiceImpl,
 )
-from llama_stack.core.datatypes import StackRunConfig
+from llama_stack.core.datatypes import StackConfig
 from llama_stack.core.storage.datatypes import (
     ServerStoresConfig,
     SqliteSqlStoreConfig,
@@ -44,7 +44,7 @@ async def service():
         ),
     )
     register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
-    run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
+    run_config = StackConfig(image_name="test", apis=[], providers={}, storage=storage)
 
     config = ConversationServiceConfig(run_config=run_config, policy=[])
     service = ConversationServiceImpl(config, {})
@@ -151,7 +151,7 @@ async def test_policy_configuration():
         ),
     )
     register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
-    run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
+    run_config = StackConfig(image_name="test", apis=[], providers={}, storage=storage)
 
     config = ConversationServiceConfig(run_config=run_config, policy=restrictive_policy)
     service = ConversationServiceImpl(config, {})

@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock
 
 import pytest
 
-from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
+from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackConfig, VectorStoresConfig
 from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
 from llama_stack.core.storage.datatypes import ServerStoresConfig, StorageConfig
 from llama_stack_api import Api, ListModelsResponse, ListShieldsResponse, Model, ModelType, Shield
@@ -19,7 +19,7 @@ from llama_stack_api import Api, ListModelsResponse, ListShieldsResponse, Model,
 class TestVectorStoresValidation:
     async def test_validate_missing_model(self):
         """Test validation fails when model not found."""
-        run_config = StackRunConfig(
+        run_config = StackConfig(
             image_name="test",
             providers={},
             storage=StorageConfig(
@@ -47,7 +47,7 @@ class TestVectorStoresValidation:
 
     async def test_validate_success(self):
         """Test validation passes with valid model."""
-        run_config = StackRunConfig(
+        run_config = StackConfig(
             image_name="test",
             providers={},
             storage=StorageConfig(

@@ -11,7 +11,7 @@ from pydantic import ValidationError
 
 from llama_stack.core.datatypes import (
     LLAMA_STACK_RUN_CONFIG_VERSION,
-    StackRunConfig,
+    StackConfig,
 )
 from llama_stack.core.storage.datatypes import (
     InferenceStoreReference,
@@ -51,7 +51,7 @@ def _base_run_config(**overrides):
             ),
         ),
     )
-    return StackRunConfig(
+    return StackConfig(
         version=LLAMA_STACK_RUN_CONFIG_VERSION,
         image_name="test-distro",
         apis=[],

@@ -11,7 +11,7 @@ import pytest
 import yaml
 from pydantic import BaseModel, Field, ValidationError
 
-from llama_stack.core.datatypes import Api, Provider, StackRunConfig
+from llama_stack.core.datatypes import Api, Provider, StackConfig
 from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
 from llama_stack.core.storage.datatypes import (
     InferenceStoreReference,
@@ -53,7 +53,7 @@ def _default_storage() -> StorageConfig:
     )
 
 
-def make_stack_config(**overrides) -> StackRunConfig:
+def make_stack_config(**overrides) -> StackConfig:
     storage = overrides.pop("storage", _default_storage())
     defaults = dict(
         image_name="test_image",
@@ -62,7 +62,7 @@ def make_stack_config(**overrides) -> StackRunConfig:
         storage=storage,
     )
     defaults.update(overrides)
-    return StackRunConfig(**defaults)
+    return StackConfig(**defaults)
 
 
 @pytest.fixture

@@ -27,7 +27,7 @@ async def temp_prompt_store(tmp_path_factory):
     temp_dir = tmp_path_factory.getbasetemp()
     db_path = str(temp_dir / f"{unique_id}.db")
 
-    from llama_stack.core.datatypes import StackRunConfig
+    from llama_stack.core.datatypes import StackConfig
 
     storage = StorageConfig(
         backends={
@@ -41,7 +41,7 @@ async def temp_prompt_store(tmp_path_factory):
             prompts=KVStoreReference(backend="kv_test", namespace="prompts"),
         ),
     )
-    mock_run_config = StackRunConfig(
+    mock_run_config = StackConfig(
         image_name="test-distribution",
         apis=[],
         providers={},

@@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, MagicMock
 
 from pydantic import BaseModel, Field
 
-from llama_stack.core.datatypes import Api, Provider, StackRunConfig
+from llama_stack.core.datatypes import Api, Provider, StackConfig
 from llama_stack.core.resolver import resolve_impls
 from llama_stack.core.routers.inference import InferenceRouter
 from llama_stack.core.routing_tables.models import ModelsRoutingTable
@@ -71,7 +71,7 @@ class SampleImpl:
     pass
 
 
-def make_run_config(**overrides) -> StackRunConfig:
+def make_run_config(**overrides) -> StackConfig:
     storage = overrides.pop(
         "storage",
         StorageConfig(
@@ -97,7 +97,7 @@ def make_run_config(**overrides) -> StackRunConfig:
         storage=storage,
     )
     defaults.update(overrides)
-    return StackRunConfig(**defaults)
+    return StackConfig(**defaults)
 
 
 async def test_resolve_impls_basic():
