fix run-config/config-file to config

This commit is contained in:
Xi Yan 2024-09-10 12:21:51 -07:00
parent ace3953926
commit 6c97e84372
4 changed files with 14 additions and 16 deletions

View file

@ -85,7 +85,7 @@ class StackBuild(Subcommand):
choices=[v.value for v in BuildType],
)
self.parser.add_argument(
"--config-file",
"--config",
type=str,
help="Path to a config file to use for the build",
)
@ -170,14 +170,12 @@ class StackBuild(Subcommand):
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
if args.config_file:
with open(args.config_file, "r") as f:
if args.config:
with open(args.config, "r") as f:
try:
build_config = BuildConfig(**yaml.safe_load(f))
except Exception as e:
self.parser.error(
f"Could not parse config file {args.config_file}: {e}"
)
self.parser.error(f"Could not parse config file {args.config}: {e}")
return
self._run_stack_build_command_from_build_config(build_config)
return

View file

@ -54,7 +54,7 @@ class StackConfigure(Subcommand):
choices=[v.value for v in BuildType],
)
self.parser.add_argument(
"--config-file",
"--config",
type=str,
help="Path to a config file to use for the build",
)
@ -62,8 +62,8 @@ class StackConfigure(Subcommand):
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.core.package import BuildType
if args.config_file:
with open(args.config_file, "r") as f:
if args.config:
with open(args.config, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
build_type = BuildType(build_config.package_type)
distribution = build_config.distribution

View file

@ -60,7 +60,7 @@ class StackRun(Subcommand):
default=False,
)
self.parser.add_argument(
"--run-config",
"--config",
type=str,
help="Path to config file to use for the run",
)
@ -69,8 +69,8 @@ class StackRun(Subcommand):
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.core.package import BuildType
if args.run_config:
path = args.run_config
if args.config:
path = args.config
else:
build_type = BuildType(args.type)
build_dir = BUILDS_BASE_DIR / args.distribution / build_type.descriptor()