mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
fix: rename StackRunConfig to StackConfig
since this object represents our config for list-deps, run, etc lets rename it to simply `StackConfig` Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
17f8ab31b5
commit
4a3f9151e3
23 changed files with 72 additions and 72 deletions
|
|
@ -12,7 +12,7 @@ import yaml
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.core.build import get_provider_dependencies
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.datatypes import Provider, StackConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Api
|
||||
|
|
@ -78,7 +78,7 @@ def run_stack_list_deps_command(args: argparse.Namespace) -> None:
|
|||
with open(config_file) as f:
|
||||
try:
|
||||
contents = yaml.safe_load(f)
|
||||
run_config = StackRunConfig(**contents)
|
||||
run_config = StackConfig(**contents)
|
||||
except Exception as e:
|
||||
cprint(
|
||||
f"Could not parse config file {config_file}: {e}",
|
||||
|
|
@ -119,7 +119,7 @@ def run_stack_list_deps_command(args: argparse.Namespace) -> None:
|
|||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
run_config = StackRunConfig(providers=provider_list, image_name="providers-run")
|
||||
run_config = StackConfig(providers=provider_list, image_name="providers-run")
|
||||
|
||||
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(run_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.datatypes import Api, Provider, StackConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
|
|
@ -156,7 +156,7 @@ class StackRun(Subcommand):
|
|||
|
||||
# Write config to disk in providers-run directory
|
||||
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||
config_file = distro_dir / "run.yaml"
|
||||
config_file = distro_dir / "config.yaml"
|
||||
|
||||
logger.info(f"Writing generated config to: {config_file}")
|
||||
with open(config_file, "w") as f:
|
||||
|
|
@ -194,7 +194,7 @@ class StackRun(Subcommand):
|
|||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
config = StackConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or "0.0.0.0"
|
||||
|
|
@ -318,7 +318,7 @@ class StackRun(Subcommand):
|
|||
),
|
||||
)
|
||||
|
||||
return StackRunConfig(
|
||||
return StackConfig(
|
||||
image_name="providers-run",
|
||||
apis=apis,
|
||||
providers=providers,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from termcolor import cprint
|
|||
from llama_stack.core.datatypes import (
|
||||
BuildConfig,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
StackConfig,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
|
|
@ -61,7 +61,7 @@ def generate_run_config(
|
|||
"""
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
distro_dir = DISTRIBS_BASE_DIR / image_name
|
||||
run_config = StackRunConfig(
|
||||
run_config = StackConfig(
|
||||
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
|
||||
image_name=image_name,
|
||||
apis=apis,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import sys
|
|||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.datatypes import StackConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.distributions.template import DistributionTemplate
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -36,7 +36,7 @@ class ApiInput(BaseModel):
|
|||
|
||||
|
||||
def get_provider_dependencies(
|
||||
config: StackRunConfig,
|
||||
config: StackConfig,
|
||||
) -> tuple[list[str], list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
if isinstance(config, DistributionTemplate):
|
||||
|
|
@ -83,7 +83,7 @@ def get_provider_dependencies(
|
|||
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
|
||||
|
||||
|
||||
def print_pip_install_help(config: StackRunConfig):
|
||||
def print_pip_install_help(config: StackConfig):
|
||||
normal_deps, special_deps, _ = get_provider_dependencies(config)
|
||||
|
||||
cprint(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from llama_stack.core.datatypes import (
|
|||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
StackConfig,
|
||||
)
|
||||
from llama_stack.core.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
|
|
@ -44,7 +44,7 @@ def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provi
|
|||
)
|
||||
|
||||
|
||||
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
|
||||
def configure_api_providers(config: StackConfig, build_spec: DistributionSpec) -> StackConfig:
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
if is_nux:
|
||||
|
|
@ -192,7 +192,7 @@ def upgrade_from_routing_table(
|
|||
return config_dict
|
||||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackConfig:
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
|
@ -200,4 +200,4 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
|
|||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
return StackConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing import Any, Literal
|
|||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
||||
from llama_stack.core.datatypes import AccessRule, StackConfig
|
||||
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -36,7 +36,7 @@ class ConversationServiceConfig(BaseModel):
|
|||
:param policy: Access control rules
|
||||
"""
|
||||
|
||||
run_config: StackRunConfig
|
||||
run_config: StackConfig
|
||||
policy: list[AccessRule] = []
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -490,7 +490,7 @@ class ServerConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
class StackConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
image_name: str = Field(
|
||||
|
|
@ -565,7 +565,7 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_server_stores(self) -> "StackRunConfig":
|
||||
def validate_server_stores(self) -> "StackConfig":
|
||||
backend_map = self.storage.backends
|
||||
stores = self.storage.stores
|
||||
kv_backends = {
|
||||
|
|
|
|||
|
|
@ -7,14 +7,14 @@
|
|||
|
||||
import yaml
|
||||
|
||||
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
|
||||
from llama_stack.core.datatypes import BuildConfig, StackConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Api, ExternalApiSpec
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
|
||||
def load_external_apis(config: StackConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
|
||||
"""Load external API specifications from the configured directory.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from importlib.metadata import version
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.datatypes import StackConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack_api import (
|
||||
|
|
@ -22,7 +22,7 @@ from llama_stack_api import (
|
|||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
run_config: StackConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
|
|
@ -40,7 +40,7 @@ class DistributionInspectImpl(Inspect):
|
|||
pass
|
||||
|
||||
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
|
||||
run_config: StackRunConfig = self.config.run_config
|
||||
run_config: StackConfig = self.config.run_config
|
||||
|
||||
# Helper function to determine if a route should be included based on api_filter
|
||||
def should_include_route(webmethod) -> bool:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.datatypes import StackConfig
|
||||
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
|
||||
|
||||
|
|
@ -20,7 +20,7 @@ class PromptServiceConfig(BaseModel):
|
|||
:param run_config: Stack run configuration containing distribution info
|
||||
"""
|
||||
|
||||
run_config: StackRunConfig
|
||||
run_config: StackConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@ from pydantic import BaseModel
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import HealthResponse, HealthStatus, ListProvidersResponse, ProviderInfo, Providers
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .datatypes import StackConfig
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ProviderImplConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
run_config: StackConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
|
|
@ -42,7 +42,7 @@ class ProviderImpl(Providers):
|
|||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
providers_health = await self.get_providers_health()
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from llama_stack.core.datatypes import (
|
|||
AutoRoutedProviderSpec,
|
||||
Provider,
|
||||
RoutingTableProviderSpec,
|
||||
StackRunConfig,
|
||||
StackConfig,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import load_external_apis
|
||||
|
|
@ -147,7 +147,7 @@ ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
|
|||
|
||||
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig,
|
||||
run_config: StackConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
|
|
@ -217,7 +217,7 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
|
|||
|
||||
|
||||
def validate_and_prepare_providers(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
|
||||
run_config: StackConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
|
||||
) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
|
||||
providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
|
||||
|
|
@ -261,7 +261,7 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
|
|||
|
||||
|
||||
def sort_providers_by_deps(
|
||||
providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
|
||||
providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackConfig
|
||||
) -> list[tuple[str, ProviderWithSpec]]:
|
||||
"""Sorts providers based on their dependencies."""
|
||||
sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
|
||||
|
|
@ -278,7 +278,7 @@ async def instantiate_providers(
|
|||
sorted_providers: list[tuple[str, ProviderWithSpec]],
|
||||
router_apis: set[Api],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
run_config: StackConfig,
|
||||
policy: list[AccessRule],
|
||||
internal_impls: dict[Api, Any] | None = None,
|
||||
) -> dict[Api, Any]:
|
||||
|
|
@ -357,7 +357,7 @@ async def instantiate_provider(
|
|||
deps: dict[Api, Any],
|
||||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
run_config: StackConfig,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from llama_stack.core.datatypes import (
|
|||
AccessRule,
|
||||
RoutedProtocol,
|
||||
)
|
||||
from llama_stack.core.stack import StackRunConfig
|
||||
from llama_stack.core.datatypes import StackConfig
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack_api import Api, RoutingTable
|
||||
|
|
@ -51,7 +51,7 @@ async def get_routing_table_impl(
|
|||
|
||||
|
||||
async def get_auto_router_impl(
|
||||
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
|
||||
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackConfig, policy: list[AccessRule]
|
||||
) -> Any:
|
||||
from .datasets import DatasetIORouter
|
||||
from .eval_scoring import EvalRouter, ScoringRouter
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from pydantic import BaseModel, ValidationError
|
|||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
StackRunConfig,
|
||||
StackConfig,
|
||||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
|
|
@ -149,7 +149,7 @@ class StackApp(FastAPI):
|
|||
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, config: StackRunConfig, *args, **kwargs):
|
||||
def __init__(self, config: StackConfig, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.stack: Stack = Stack(config)
|
||||
|
||||
|
|
@ -385,7 +385,7 @@ def create_app() -> StackApp:
|
|||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
config = StackConfig(**cast_image_name_to_string(config))
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
||||
|
|
@ -506,7 +506,7 @@ def create_app() -> StackApp:
|
|||
return app
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
def _log_run_config(run_config: StackConfig):
|
||||
"""Logs the run config with redacted fields and disabled providers removed."""
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.datatypes import Provider, SafetyConfig, StackConfig, VectorStoresConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
|
|
@ -108,7 +108,7 @@ REGISTRY_REFRESH_TASK = None
|
|||
TEST_RECORDING_CONTEXT = None
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config.registered_resources, rsrc)
|
||||
if api not in impls:
|
||||
|
|
@ -341,7 +341,7 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|||
return config_dict
|
||||
|
||||
|
||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
Args:
|
||||
|
|
@ -373,7 +373,7 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
impls[Api.conversations] = conversations_impl
|
||||
|
||||
|
||||
def _initialize_storage(run_config: StackRunConfig):
|
||||
def _initialize_storage(run_config: StackConfig):
|
||||
kv_backends: dict[str, StorageBackendConfig] = {}
|
||||
sql_backends: dict[str, StorageBackendConfig] = {}
|
||||
for backend_name, backend_config in run_config.storage.backends.items():
|
||||
|
|
@ -393,7 +393,7 @@ def _initialize_storage(run_config: StackRunConfig):
|
|||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
def __init__(self, run_config: StackConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
self.provider_registry = provider_registry
|
||||
self.impls = None
|
||||
|
|
@ -499,7 +499,7 @@ async def refresh_registry_task(impls: dict[Api, Any]):
|
|||
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
|
||||
def get_stack_run_config_from_distro(distro: str) -> StackConfig:
|
||||
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml"
|
||||
|
||||
with importlib.resources.as_file(distro_path) as path:
|
||||
|
|
@ -507,12 +507,12 @@ def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
|
|||
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
|
||||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
return StackConfig(**replace_env_vars(run_config))
|
||||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
|
||||
) -> StackRunConfig:
|
||||
) -> StackConfig:
|
||||
"""
|
||||
Create an adhoc distribution from a list of API providers.
|
||||
|
||||
|
|
@ -552,7 +552,7 @@ def run_config_from_adhoc_config_spec(
|
|||
config=provider_config,
|
||||
)
|
||||
]
|
||||
config = StackRunConfig(
|
||||
config = StackConfig(
|
||||
image_name="distro-test",
|
||||
apis=list(provider_configs_by_api.keys()),
|
||||
providers=provider_configs_by_api,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue