From 987e1cafc461261b856ceb1e3bb72fad5587bc0a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 11 Sep 2024 10:43:36 -0700 Subject: [PATCH] only consume config as argument --- llama_toolchain/cli/stack/build.py | 45 +++++-------------------- llama_toolchain/cli/stack/configure.py | 18 +++++----- llama_toolchain/cli/stack/run.py | 34 +++++-------------- llama_toolchain/core/build_conda_env.sh | 2 +- llama_toolchain/core/datatypes.py | 6 ++-- llama_toolchain/core/package.py | 25 ++++++-------- 6 files changed, 39 insertions(+), 91 deletions(-) diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index b8ce1bf39..d2d7df6d0 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -8,7 +8,6 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.core.datatypes import * # noqa: F403 - import yaml @@ -52,33 +51,9 @@ class StackBuild(Subcommand): from llama_toolchain.core.distribution_registry import ( available_distribution_specs, ) - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType allowed_ids = [d.distribution_type for d in available_distribution_specs()] - self.parser.add_argument( - "--distribution", - type=str, - help='Distribution to build (either "adhoc" OR one of: {})'.format( - allowed_ids - ), - ) - self.parser.add_argument( - "--api-providers", - nargs="?", - help="Comma separated list of (api=provider) tuples", - ) - - self.parser.add_argument( - "--name", - type=str, - help="Name of the build target (image, conda env)", - ) - self.parser.add_argument( - "--package-type", - type=str, - default="conda_env", - choices=[v.value for v in BuildType], - ) self.parser.add_argument( "--config", type=str, @@ -94,7 +69,7 @@ class StackBuild(Subcommand): from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR from llama_toolchain.common.serialize import EnumEncoder from llama_toolchain.core.distribution_registry import resolve_distribution_spec - from llama_toolchain.core.package import ApiInput, build_package, BuildType + from llama_toolchain.core.package import ApiInput, build_package, ImageType from termcolor import cprint api_inputs = [] @@ -146,7 +121,7 @@ class StackBuild(Subcommand): build_package( api_inputs, - build_type=BuildType(build_config.package_type), + image_type=ImageType(build_config.image_type), name=build_config.name, distribution_type=build_config.distribution, docker_image=docker_image, @@ -154,9 +129,7 @@ class StackBuild(Subcommand): # save build.yaml spec for building same distribution again build_dir = ( - DISTRIBS_BASE_DIR - / build_config.distribution - / BuildType(build_config.package_type).descriptor() + DISTRIBS_BASE_DIR / build_config.distribution / build_config.image_type ) os.makedirs(build_dir, exist_ok=True) build_file_path = build_dir / f"{build_config.name}-build.yaml" @@ -171,6 +144,9 @@ class StackBuild(Subcommand): ) def _run_stack_build_command(self, args: argparse.Namespace) -> None: + from llama_toolchain.common.prompt_for_config import prompt_for_config + from llama_toolchain.core.dynamic import instantiate_class_type + if args.config: with open(args.config, "r") as f: try: @@ -181,10 +157,5 @@ class StackBuild(Subcommand): self._run_stack_build_command_from_build_config(build_config) return - build_config = BuildConfig( - name=args.name, - distribution=args.distribution, - package_type=args.package_type, - api_providers=args.api_providers, - ) + build_config = prompt_for_config(BuildConfig, None) self._run_stack_build_command_from_build_config(build_config) diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py index 667177601..c85b542e1 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_toolchain/cli/stack/configure.py @@ -34,7 +34,7 @@ class StackConfigure(Subcommand): from llama_toolchain.core.distribution_registry import ( available_distribution_specs, ) - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType allowed_ids = [d.distribution_type for d in available_distribution_specs()] self.parser.add_argument( @@ -48,10 +48,10 @@ class StackConfigure(Subcommand): help="Name of the build", ) self.parser.add_argument( - "--package-type", + "--image-type", type=str, - default="conda_env", - choices=[v.value for v in BuildType], + default="conda", + choices=[v.value for v in ImageType], ) self.parser.add_argument( "--config", @@ -60,22 +60,20 @@ class StackConfigure(Subcommand): ) def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType if args.config: with open(args.config, "r") as f: build_config = BuildConfig(**yaml.safe_load(f)) - build_type = BuildType(build_config.package_type) + image_type = ImageType(build_config.image_type) distribution = build_config.distribution name = build_config.name else: - build_type = BuildType(args.package_type) + image_type = ImageType(args.image_type) name = args.name distribution = args.distribution - config_file = ( - BUILDS_BASE_DIR / distribution / build_type.descriptor() / f"{name}.yaml" - ) + config_file = BUILDS_BASE_DIR / distribution / image_type.value / f"{name}.yaml" if not config_file.exists(): self.parser.error( f"Could not find {config_file}. Please run `llama stack build` first" diff --git a/llama_toolchain/cli/stack/run.py b/llama_toolchain/cli/stack/run.py index 118cb0665..d040cb1f7 100644 --- a/llama_toolchain/cli/stack/run.py +++ b/llama_toolchain/cli/stack/run.py @@ -29,23 +29,12 @@ class StackRun(Subcommand): self.parser.set_defaults(func=self._run_stack_run_cmd) def _add_arguments(self): - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType self.parser.add_argument( - "--distribution", + "config", type=str, - help="Distribution whose build you want to start", - ) - self.parser.add_argument( - "--name", - type=str, - help="Name of the build you want to start", - ) - self.parser.add_argument( - "--type", - type=str, - default="conda_env", - choices=[v.value for v in BuildType], + help="Path to config file to use for the run", ) self.parser.add_argument( "--port", @@ -59,23 +48,16 @@ class StackRun(Subcommand): help="Disable IPv6 support", default=False, ) - self.parser.add_argument( - "--config", - type=str, - help="Path to config file to use for the run", - ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_with_pty - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType - if args.config: - path = args.config - else: - build_type = BuildType(args.type) - build_dir = BUILDS_BASE_DIR / args.distribution / build_type.descriptor() - path = build_dir / f"{args.name}.yaml" + if not args.config: + self.parser.error("Must specify a config file to run") + return + path = args.config config_file = Path(path) if not config_file.exists(): diff --git a/llama_toolchain/core/build_conda_env.sh b/llama_toolchain/core/build_conda_env.sh index 866ca3b94..2e52e48e2 100755 --- a/llama_toolchain/core/build_conda_env.sh +++ b/llama_toolchain/core/build_conda_env.sh @@ -117,4 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies" printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n" -$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure --distribution $distribution_type --name "$build_name" --package-type conda_env +$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure --distribution $distribution_type --name "$build_name" --image-type conda diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py index e927cee1d..2405a57ce 100644 --- a/llama_toolchain/core/datatypes.py +++ b/llama_toolchain/core/datatypes.py @@ -200,7 +200,7 @@ class BuildConfig(BaseModel): default_factory=list, description="List of API provider names to build", ) - package_type: str = Field( - default="conda_env", - description="Type of package to build (conda_env | container)", + image_type: str = Field( + default="conda", + description="Type of package to build (conda | container)", ) diff --git a/llama_toolchain/core/package.py b/llama_toolchain/core/package.py index ab4346a71..0af75c3a6 100644 --- a/llama_toolchain/core/package.py +++ b/llama_toolchain/core/package.py @@ -12,24 +12,21 @@ from typing import List, Optional import pkg_resources import yaml -from pydantic import BaseModel - -from termcolor import cprint from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR from llama_toolchain.common.exec import run_with_pty from llama_toolchain.common.serialize import EnumEncoder +from pydantic import BaseModel + +from termcolor import cprint from llama_toolchain.core.datatypes import * # noqa: F403 from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES -class BuildType(Enum): - container = "container" - conda_env = "conda_env" - - def descriptor(self) -> str: - return "docker" if self == self.container else "conda" +class ImageType(Enum): + docker = "docker" + conda = "conda" class Dependencies(BaseModel): @@ -44,7 +41,7 @@ class ApiInput(BaseModel): def build_package( api_inputs: List[ApiInput], - build_type: BuildType, + image_type: ImageType, name: str, distribution_type: Optional[str] = None, docker_image: Optional[str] = None, @@ -52,7 +49,7 @@ def build_package( if not distribution_type: distribution_type = "adhoc" - build_dir = BUILDS_BASE_DIR / distribution_type / build_type.descriptor() + build_dir = BUILDS_BASE_DIR / distribution_type / image_type.value os.makedirs(build_dir, exist_ok=True) package_name = name.replace("::", "-") @@ -106,14 +103,14 @@ def build_package( ) c.distribution_type = distribution_type - c.docker_image = package_name if build_type == BuildType.container else None - c.conda_env = package_name if build_type == BuildType.conda_env else None + c.docker_image = package_name if image_type == ImageType.docker else None + c.conda_env = package_name if image_type == ImageType.conda else None with open(package_file, "w") as f: to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder)) f.write(yaml.dump(to_write, sort_keys=False)) - if build_type == BuildType.container: + if image_type == ImageType.docker: script = pkg_resources.resource_filename( "llama_toolchain", "core/build_container.sh" )