diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index b573b2edc..3f94b1e2c 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None: config = parse_and_maybe_upgrade_config(config_dict) if config.external_providers_dir and not config.external_providers_dir.exists(): config.external_providers_dir.mkdir(exist_ok=True) - run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) - run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) + run_args = formulate_run_args(args.image_type, args.image_name) + run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)]) run_command(run_args) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index f4a119522..3cb2e213c 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -82,39 +82,6 @@ class StackRun(Subcommand): return ImageType.CONDA.value, args.image_name return args.image_type, args.image_name - def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]: - """Resolve config file path and template name from args.config""" - from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR - - if not args.config: - return None, None - - config_file = Path(args.config) - has_yaml_suffix = args.config.endswith(".yaml") - template_name = None - - if not config_file.exists() and not has_yaml_suffix: - # check if this is a template - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml" - if config_file.exists(): - template_name = args.config - - if not config_file.exists() and not has_yaml_suffix: - # check if it's a build config saved to ~/.llama dir - config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml") - - if not config_file.exists(): - self.parser.error( - f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file" - ) - - if not config_file.is_file(): - self.parser.error( - f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}" - ) - - return config_file, template_name - def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml @@ -125,8 +92,15 @@ class StackRun(Subcommand): self._start_ui_development_server(args.port) image_type, image_name = self._get_image_type_and_name(args) - # Resolve config file and template name first - config_file, template_name = self._resolve_config_and_template(args) + if args.config: + try: + from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template + + config_file = resolve_config_or_template(args.config, Mode.RUN) + except ValueError as e: + self.parser.error(str(e)) + else: + config_file = None # Check if config is required based on image type if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file: @@ -164,18 +138,14 @@ class StackRun(Subcommand): if callable(getattr(args, arg)): continue if arg == "config": - if template_name: - server_args.template = str(template_name) - else: - # Set the config file path - server_args.config = str(config_file) + server_args.config = str(config_file) else: setattr(server_args, arg, getattr(args, arg)) # Run the server server_main(server_args) else: - run_args = formulate_run_args(image_type, image_name, config, template_name) + run_args = formulate_run_args(image_type, image_name) run_args.extend([str(args.port)]) diff --git a/llama_stack/cli/utils.py b/llama_stack/cli/utils.py new file mode 100644 index 000000000..433627cc0 --- /dev/null +++ b/llama_stack/cli/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + + +def add_config_template_args(parser: argparse.ArgumentParser): + """Add unified config/template arguments with backward compatibility.""" + group = parser.add_mutually_exclusive_group(required=True) + + group.add_argument( + "config", + nargs="?", + help="Configuration file path or template name", + ) + + # Backward compatibility arguments (deprecated) + group.add_argument( + "--config", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. Configuration file path", + ) + + group.add_argument( + "--template", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. Template name", + ) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index ede65e8d6..e58c28f2e 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -32,6 +32,7 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.cli.utils import add_config_template_args from llama_stack.distribution.access_control.access_control import AccessDeniedError from llama_stack.distribution.datatypes import ( AuthenticationRequiredError, @@ -53,6 +54,7 @@ from llama_stack.distribution.stack import ( validate_env_pair, ) from llama_stack.distribution.utils.config import redact_sensitive_fields +from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api @@ -377,20 +379,8 @@ class ClientVersionMiddleware: def main(args: argparse.Namespace | None = None): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") - parser.add_argument( - "--yaml-config", - dest="config", - help="(Deprecated) Path to YAML configuration file - use --config instead", - ) - parser.add_argument( - "--config", - dest="config", - help="Path to YAML configuration file", - ) - parser.add_argument( - "--template", - help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)", - ) + + add_config_template_args(parser) parser.add_argument( "--port", type=int, @@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - log_line = "" - if hasattr(args, "config") and args.config: - # if the user provided a config file, use it, even if template was specified - config_file = Path(args.config) - if not config_file.exists(): - raise ValueError(f"Config file {config_file} does not exist") - log_line = f"Using config file: {config_file}" - elif hasattr(args, "template") and args.template: - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" - if not config_file.exists(): - raise ValueError(f"Template {args.template} does not exist") - log_line = f"Using template {args.template} config file: {config_file}" - else: - raise ValueError("Either --config or --template must be provided") + config_file = resolve_config_or_template(args.config, Mode.RUN) logger_config = None with open(config_file) as fp: @@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None): config = replace_env_vars(config_contents) config = StackRunConfig(**cast_image_name_to_string(config)) - # now that the logger is initialized, print the line about which type of config we are using. - logger.info(log_line) - _log_run_config(run_config=config) app = FastAPI( diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh index 85bfceec4..77a7dc92e 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/distribution/start_stack.sh @@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then set -x if [ -n "$yaml_config" ]; then - yaml_config_arg="--config $yaml_config" + yaml_config_arg="$yaml_config" else yaml_config_arg="" fi diff --git a/llama_stack/distribution/utils/config_resolution.py b/llama_stack/distribution/utils/config_resolution.py new file mode 100644 index 000000000..7e8de1242 --- /dev/null +++ b/llama_stack/distribution/utils/config_resolution.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import StrEnum +from pathlib import Path + +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="config_resolution") + + +TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates" + + +class Mode(StrEnum): + RUN = "run" + BUILD = "build" + + +def resolve_config_or_template( + config_or_template: str, + mode: Mode = Mode.RUN, +) -> Path: + """ + Resolve a config/template argument to a concrete config file path. + + Args: + config_or_template: User input (file path, template name, or built distribution) + mode: Mode resolving for ("run", "build", "server") + + Returns: + Path to the resolved config file + + Raises: + ValueError: If resolution fails + """ + + # Strategy 1: Try as file path first + config_path = Path(config_or_template) + if config_path.exists() and config_path.is_file(): + logger.info(f"Using file path: {config_path}") + return config_path.resolve() + + # Strategy 2: Try as template name (if no .yaml extension) + if not config_or_template.endswith(".yaml"): + template_config = _get_template_config_path(config_or_template, mode) + if template_config.exists(): + logger.info(f"Using template: {template_config}") + return template_config + + # Strategy 3: Try as built distribution name + distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + # Strategy 4: Failed - provide helpful error + raise ValueError(_format_resolution_error(config_or_template, mode)) + + +def _get_template_config_path(template_name: str, mode: Mode) -> Path: + """Get the config file path for a template.""" + return TEMPLATE_DIR / template_name / f"{mode}.yaml" + + +def _format_resolution_error(config_or_template: str, mode: Mode) -> str: + """Format a helpful error message for resolution failures.""" + from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + + template_path = _get_template_config_path(config_or_template, mode) + distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + + available_templates = _get_available_templates() + templates_str = ", ".join(available_templates) if available_templates else "none found" + + return f"""Could not resolve config or template '{config_or_template}'. + +Tried the following locations: + 1. As file path: {Path(config_or_template).resolve()} + 2. As template: {template_path} + 3. As built distribution: ({distrib_path}, {distrib_path2}) + +Available templates: {templates_str} + +Did you mean one of these templates? +{_format_template_suggestions(available_templates, config_or_template)} +""" + + +def _get_available_templates() -> list[str]: + """Get list of available template names.""" + if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists(): + return [] + + return list( + set( + [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + ) + ) + + +def _format_template_suggestions(templates: list[str], user_input: str) -> str: + """Format template suggestions for error messages, showing closest matches first.""" + if not templates: + return " (no templates found)" + + import difflib + + # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive) + close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3) + display_templates = close_matches if close_matches else templates[:3] + + suggestions = [f" - {t}" for t in display_templates] + return "\n".join(suggestions) diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index 2db01689f..c646ae821 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -21,7 +21,7 @@ from pathlib import Path from llama_stack.distribution.utils.image_types import LlamaStackImageType -def formulate_run_args(image_type, image_name, config, template_name) -> list: +def formulate_run_args(image_type: str, image_name: str) -> list[str]: env_name = "" if image_type == LlamaStackImageType.CONDA.value: