diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py index 7ae5c1378..615c2ce04 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_toolchain/cli/stack/configure.py @@ -58,7 +58,52 @@ class StackConfigure(Subcommand): self._configure_llama_distribution(build_config) - def _configure_llama_distribution(self, build_config: BuildConfig) -> None: + def _configure_llama_distribution(self, build_config: BuildConfig): + from llama_toolchain.common.serialize import EnumEncoder + from llama_toolchain.core.configure import configure_api_providers + + builds_dir = BUILDS_BASE_DIR / build_config.image_type + os.makedirs(builds_dir, exist_ok=True) + package_name = build_config.name.replace("::", "-") + package_file = builds_dir / f"{package_name}-run.yaml" + + api2providers = build_config.distribution_spec.providers + + stub_config = { + api_str: {"provider_type": provider} + for api_str, provider in api2providers.items() + } + + if package_file.exists(): + cprint( + f"Configuration already exists for {build_config.distribution}. Will overwrite...", + "yellow", + attrs=["bold"], + ) + config = PackageConfig(**yaml.safe_load(package_file.read_text())) + else: + config = PackageConfig( + built_at=datetime.now(), + package_name=package_name, + providers=stub_config, + ) + + config.providers = configure_api_providers(config.providers) + config.distribution_type = build_config.distribution_spec.distribution_type + config.docker_image = ( + package_name if build_config.image_type == "docker" else None + ) + config.conda_env = package_name if build_config.image_type == "conda" else None + + with open(package_file, "w") as f: + to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder)) + f.write(yaml.dump(to_write, sort_keys=False)) + + print(f"YAML configuration has been written to {package_file}") + + def _configure_llama_distribution_DEPRECATED( + self, build_config: BuildConfig + ) -> None: from llama_toolchain.common.serialize import EnumEncoder from llama_toolchain.core.configure import configure_api_providers from llama_toolchain.core.distribution import api_providers