From d99c06fce87325947704d6fd8bd2deab72ab76b4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 30 Aug 2024 15:03:23 -0700 Subject: [PATCH] Fix stack start --- llama_toolchain/cli/stack/start.py | 33 ++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/llama_toolchain/cli/stack/start.py b/llama_toolchain/cli/stack/start.py index 433d5b8dc..d090bdf6a 100644 --- a/llama_toolchain/cli/stack/start.py +++ b/llama_toolchain/cli/stack/start.py @@ -13,6 +13,7 @@ import yaml from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR class StackStart(Subcommand): @@ -28,10 +29,23 @@ class StackStart(Subcommand): self.parser.set_defaults(func=self._run_stack_start_cmd) def _add_arguments(self): + from llama_toolchain.core.package import BuildType + self.parser.add_argument( - "yaml_config", + "distribution", type=str, - help="Yaml config containing the API build configuration", + help="Distribution whose build you want to start", + ) + self.parser.add_argument( + "--build-name", + type=str, + help="Name of the API build you want to start", + ) + self.parser.add_argument( + "--build-type", + type=str, + default="conda_env", + choices=[v.value for v in BuildType], ) self.parser.add_argument( "--port", @@ -48,13 +62,16 @@ class StackStart(Subcommand): def _run_stack_start_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_with_pty + from llama_toolchain.core.package import BuildType - config_file = Path(args.yaml_config) - if not config_file.exists(): - self.parser.error( - f"Could not find {config_file}. Please run `llama stack build` first" - ) - return + if args.build_name.endswith(".yaml"): + path = args.build_name + else: + build_type = BuildType(args.build_type) + build_dir = BUILDS_BASE_DIR / args.distribution / build_type.descriptor() + path = build_dir / f"{args.build_name}.yaml" + + config_file = Path(path) with open(config_file, "r") as f: config = PackageConfig(**yaml.safe_load(f))