From 3063329dad463562e59e4dd7d0ce7fd7ec2ca2be Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 28 Aug 2024 17:17:46 -0700 Subject: [PATCH] Some quick fixes to the CLI behavior to make it consistent --- llama_toolchain/cli/api/build.py | 2 +- llama_toolchain/cli/api/configure.py | 39 +++++++++++++++++++---- llama_toolchain/cli/stack/build.py | 3 +- llama_toolchain/cli/stack/configure.py | 41 +++++++++++++++++++++---- llama_toolchain/distribution/package.py | 6 ++-- 5 files changed, 74 insertions(+), 17 deletions(-) diff --git a/llama_toolchain/cli/api/build.py b/llama_toolchain/cli/api/build.py index 05f59c19e..ae7e815be 100644 --- a/llama_toolchain/cli/api/build.py +++ b/llama_toolchain/cli/api/build.py @@ -81,7 +81,7 @@ class ApiBuild(Subcommand): self.parser.add_argument( "--type", type=str, - default="container", + default="conda_env", choices=[v.value for v in BuildType], ) diff --git a/llama_toolchain/cli/api/configure.py b/llama_toolchain/cli/api/configure.py index ef48f175a..8b27ab1b1 100644 --- a/llama_toolchain/cli/api/configure.py +++ b/llama_toolchain/cli/api/configure.py @@ -32,6 +32,7 @@ class ApiConfigure(Subcommand): def _add_arguments(self): from llama_toolchain.distribution.distribution import stack_apis + from llama_toolchain.distribution.package import BuildType allowed_args = [a.name for a in stack_apis()] self.parser.add_argument( @@ -42,15 +43,41 @@ class ApiConfigure(Subcommand): self.parser.add_argument( "--build-name", type=str, - help="Name of the provider build to fully configure", - required=True, + help="(Fully qualified) name of the API build to configure. Alternatively, specify the --provider and --name options.", + required=False, + ) + + self.parser.add_argument( + "--provider", + type=str, + help="The provider chosen for the API", + required=False, + ) + self.parser.add_argument( + "--name", + type=str, + help="Name of the build target (image, conda env)", + required=False, + ) + self.parser.add_argument( + "--type", + type=str, + default="conda_env", + choices=[v.value for v in BuildType], ) def _run_api_configure_cmd(self, args: argparse.Namespace) -> None: - name = args.build_name - if not name.endswith(".yaml"): - name += ".yaml" - config_file = BUILDS_BASE_DIR / args.api / name + from llama_toolchain.distribution.package import BuildType + + if args.build_name: + name = args.build_name + if name.endswith(".yaml"): + name = name.replace(".yaml", "") + else: + build_type = BuildType(args.type) + name = f"{build_type.descriptor()}-{args.provider}-{args.name}" + + config_file = BUILDS_BASE_DIR / args.api / f"{name}.yaml" if not config_file.exists(): self.parser.error( f"Could not find {config_file}. Please run `llama api build` first" diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index ef2393c09..728275fe1 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import argparse -from typing import Dict from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.distribution.datatypes import * # noqa: F403 @@ -46,7 +45,7 @@ class StackBuild(Subcommand): self.parser.add_argument( "--type", type=str, - default="container", + default="conda_env", choices=[v.value for v in BuildType], ) diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py index 2e62238c2..80937be95 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_toolchain/cli/stack/configure.py @@ -31,19 +31,48 @@ class StackConfigure(Subcommand): self.parser.set_defaults(func=self._run_stack_configure_cmd) def _add_arguments(self): + from llama_toolchain.distribution.package import BuildType + from llama_toolchain.distribution.registry import available_distribution_specs + self.parser.add_argument( "--build-name", type=str, - help="Name of the stack build to configure", - required=True, + help="(Fully qualified) name of the stack build to configure. Alternatively, provider --distribution and --name", + required=False, + ) + allowed_ids = [d.distribution_id for d in available_distribution_specs()] + self.parser.add_argument( + "--distribution", + type=str, + choices=allowed_ids, + help="Distribution (one of: {})".format(allowed_ids), + required=False, + ) + self.parser.add_argument( + "--name", + type=str, + help="Name of the build", + required=False, + ) + self.parser.add_argument( + "--type", + type=str, + default="conda_env", + choices=[v.value for v in BuildType], ) def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: - name = args.build_name - if not name.endswith(".yaml"): - name += ".yaml" + from llama_toolchain.distribution.package import BuildType - config_file = BUILDS_BASE_DIR / "stack" / name + if args.build_name: + name = args.build_name + if name.endswith(".yaml"): + name = name.replace(".yaml", "") + else: + build_type = BuildType(args.type) + name = f"{build_type.descriptor()}-{args.distribution}-{args.name}" + + config_file = BUILDS_BASE_DIR / "stack" / f"{name}.yaml" if not config_file.exists(): self.parser.error( f"Could not find {config_file}. Please run `llama stack build` first" diff --git a/llama_toolchain/distribution/package.py b/llama_toolchain/distribution/package.py index 7b4cf56ca..e0ee58a85 100644 --- a/llama_toolchain/distribution/package.py +++ b/llama_toolchain/distribution/package.py @@ -28,6 +28,9 @@ class BuildType(Enum): container = "container" conda_env = "conda_env" + def descriptor(self) -> str: + return "image" if self == self.container else "env" + class Dependencies(BaseModel): pip_packages: List[str] @@ -77,12 +80,11 @@ def build_package( provider = distribution_id if is_stack else api1.provider api_or_stack = "stack" if is_stack else api1.api.value - build_desc = "image" if build_type == BuildType.container else "env" build_dir = BUILDS_BASE_DIR / api_or_stack os.makedirs(build_dir, exist_ok=True) - package_name = f"{build_desc}-{provider}-{name}" + package_name = f"{build_type.descriptor()}-{provider}-{name}" package_name = package_name.replace("::", "-") package_file = build_dir / f"{package_name}.yaml"