diff --git a/docs/cli_reference.md b/docs/cli_reference.md index aaac83c70..dab2ee076 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -318,7 +318,7 @@ Build spec configuration saved at /home/xiyan/.llama/distributions/local/conda/8 You can re-build package based on build config ``` -$ llama stack build --config-file ~/.llama/distributions/local/conda/8b-instruct-build.yaml +$ llama stack build --config ~/.llama/distributions/local/conda/8b-instruct-build.yaml Successfully setup conda environment. Configuring build... @@ -334,7 +334,7 @@ Build spec configuration saved at /home/xiyan/.llama/distributions/local/conda/8 You can re-configure this distribution by running: ``` -llama stack configure --config-file ~/.llama/distributions/local/conda/8b-instruct-build.yaml +llama stack configure --config ~/.llama/distributions/local/conda/8b-instruct-build.yaml ``` or @@ -386,12 +386,12 @@ Now let’s start Llama Stack Distribution Server. You need the YAML configuration file which was written out at the end by the `llama stack build` step. ``` -llama stack run --run-config ~/.llama/builds/local/conda/8b-instruct.yaml --port 5000 +llama stack run --config ~/.llama/builds/local/conda/8b-instruct.yaml --port 5000 ``` You should see the Stack server start and print the APIs that it is supporting, ``` -$ llama stack run --run-config ~/.llama/builds/local/conda/8b-instruct.yaml --port 5000 +$ llama stack run --config ~/.llama/builds/local/conda/8b-instruct.yaml --port 5000 > initializing model parallel with size 1 > initializing ddp with size 1 diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index 4eb83b6cc..5fbaf4a1f 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -85,7 +85,7 @@ class StackBuild(Subcommand): choices=[v.value for v in BuildType], ) self.parser.add_argument( - "--config-file", + "--config", type=str, help="Path to a config file to use for the build", ) @@ -170,14 +170,12 @@ class StackBuild(Subcommand): ) def _run_stack_build_command(self, args: argparse.Namespace) -> None: - if args.config_file: - with open(args.config_file, "r") as f: + if args.config: + with open(args.config, "r") as f: try: build_config = BuildConfig(**yaml.safe_load(f)) except Exception as e: - self.parser.error( - f"Could not parse config file {args.config_file}: {e}" - ) + self.parser.error(f"Could not parse config file {args.config}: {e}") return self._run_stack_build_command_from_build_config(build_config) return diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py index f90f3ba0f..667177601 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_toolchain/cli/stack/configure.py @@ -54,7 +54,7 @@ class StackConfigure(Subcommand): choices=[v.value for v in BuildType], ) self.parser.add_argument( - "--config-file", + "--config", type=str, help="Path to a config file to use for the build", ) @@ -62,8 +62,8 @@ class StackConfigure(Subcommand): def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.core.package import BuildType - if args.config_file: - with open(args.config_file, "r") as f: + if args.config: + with open(args.config, "r") as f: build_config = BuildConfig(**yaml.safe_load(f)) build_type = BuildType(build_config.package_type) distribution = build_config.distribution diff --git a/llama_toolchain/cli/stack/run.py b/llama_toolchain/cli/stack/run.py index 8a195a3f3..118cb0665 100644 --- a/llama_toolchain/cli/stack/run.py +++ b/llama_toolchain/cli/stack/run.py @@ -60,7 +60,7 @@ class StackRun(Subcommand): default=False, ) self.parser.add_argument( - "--run-config", + "--config", type=str, help="Path to config file to use for the run", ) @@ -69,8 +69,8 @@ class StackRun(Subcommand): from llama_toolchain.common.exec import run_with_pty from llama_toolchain.core.package import BuildType - if args.run_config: - path = args.run_config + if args.config: + path = args.config else: build_type = BuildType(args.type) build_dir = BUILDS_BASE_DIR / args.distribution / build_type.descriptor()