From 0981193d7888ecef1fc8627b452abc24572d539a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 10 Sep 2024 11:02:46 -0700 Subject: [PATCH] config file for build --- llama_toolchain/cli/stack/build.py | 103 ++++++++++++++++++++++------- llama_toolchain/core/datatypes.py | 16 +++++ 2 files changed, 94 insertions(+), 25 deletions(-) diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index 22bd4071f..259aca3f4 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -8,6 +8,7 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR def parse_api_provider_tuples( @@ -47,20 +48,22 @@ class StackBuild(Subcommand): self.parser.set_defaults(func=self._run_stack_build_command) def _add_arguments(self): - from llama_toolchain.core.distribution_registry import available_distribution_specs - from llama_toolchain.core.package import ( - BuildType, + from llama_toolchain.core.distribution_registry import ( + available_distribution_specs, ) + from llama_toolchain.core.package import BuildType allowed_ids = [d.distribution_type for d in available_distribution_specs()] self.parser.add_argument( "distribution", type=str, - help="Distribution to build (either \"adhoc\" OR one of: {})".format(allowed_ids), + help='Distribution to build (either "adhoc" OR one of: {})'.format( + allowed_ids + ), ) self.parser.add_argument( "api_providers", - nargs='?', + nargs="?", help="Comma separated list of (api=provider) tuples", ) @@ -71,31 +74,38 @@ class StackBuild(Subcommand): required=True, ) self.parser.add_argument( - "--type", + "--package-type", type=str, default="conda_env", choices=[v.value for v in BuildType], ) - - def _run_stack_build_command(self, args: argparse.Namespace) -> None: - from llama_toolchain.core.distribution_registry import resolve_distribution_spec - from llama_toolchain.core.package import ( - ApiInput, - BuildType, - build_package, + self.parser.add_argument( + "--config-file", + type=str, + help="Path to a config file to use for the build", ) + def _run_stack_build_command_from_build_config( + self, build_config: BuildConfig + ) -> None: + from llama_toolchain.core.distribution_registry import resolve_distribution_spec + from llama_toolchain.core.package import ApiInput, build_package, BuildType + api_inputs = [] - if args.distribution == "adhoc": - if not args.api_providers: - self.parser.error("You must specify API providers with (api=provider,...) for building an adhoc distribution") + if build_config.distribution == "adhoc": + if not build_config.api_providers: + self.parser.error( + "You must specify API providers with (api=provider,...) for building an adhoc distribution" + ) return - parsed = parse_api_provider_tuples(args.api_providers, self.parser) + parsed = parse_api_provider_tuples(build_config.api_providers, self.parser) for api, provider_spec in parsed.items(): for dep in provider_spec.api_dependencies: if dep not in parsed: - self.parser.error(f"API {api} needs dependency {dep} provided also") + self.parser.error( + f"API {api} needs dependency {dep} provided also" + ) return api_inputs.append( @@ -106,13 +116,17 @@ class StackBuild(Subcommand): ) docker_image = None else: - if args.api_providers: - self.parser.error("You cannot specify API providers for pre-registered distributions") + if build_config.api_providers: + self.parser.error( + "You cannot specify API providers for pre-registered distributions" + ) return - dist = resolve_distribution_spec(args.distribution) + dist = resolve_distribution_spec(build_config.distribution) if dist is None: - self.parser.error(f"Could not find distribution {args.distribution}") + self.parser.error( + f"Could not find distribution {build_config.distribution}" + ) return for api, provider_type in dist.providers.items(): @@ -126,8 +140,47 @@ class StackBuild(Subcommand): build_package( api_inputs, - build_type=BuildType(args.type), - name=args.name, - distribution_type=args.distribution, + build_type=BuildType(build_config.package_type), + name=build_config.name, + distribution_type=build_config.distribution, docker_image=docker_image, ) + + # save build.yaml spec for building same distribution again + build_dir = ( + DISTRIBS_BASE_DIR + / build_config.distribution + / BuildType(build_config.package_type).descriptor() + ) + os.makedirs(build_dir, exist_ok=True) + build_file_path = build_dir / f"{build_config.name}-build.yaml" + + with open(build_file_path, "w") as f: + to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder)) + f.write(yaml.dump(to_write, sort_keys=False)) + + cprint( + f"Build spec configuration saved at {str(build_file_path)}", + color="green", + ) + + def _run_stack_build_command(self, args: argparse.Namespace) -> None: + if args.config_file: + with open(args.config_file, "r") as f: + try: + build_config = BuildConfig(**yaml.safe_load(f)) + except Exception as e: + self.parser.error( + f"Could not parse config file {args.config_file}: {e}" + ) + return + self._run_stack_build_command_from_build_config(build_config) + return + + build_config = BuildConfig( + name=args.name, + distribution_type=args.distribution, + package_type=args.package_type, + api_providers=args.api_providers, + ) + self._run_stack_build_command_from_build_config(build_config) diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py index 138d20941..e927cee1d 100644 --- a/llama_toolchain/core/datatypes.py +++ b/llama_toolchain/core/datatypes.py @@ -188,3 +188,19 @@ Provider configurations for each of the APIs provided by this package. This incl the dependencies of these providers as well. """, ) + + +@json_schema_type +class BuildConfig(BaseModel): + name: str + distribution: str = Field( + default="local", description="Type of distribution to build (adhoc | {})" + ) + api_providers: Optional[str] = Field( + default_factory=list, + description="List of API provider names to build", + ) + package_type: str = Field( + default="conda_env", + description="Type of package to build (conda_env | container)", + )