add DistributionConfig, fix a bug in model download

This commit is contained in:
Ashwin Bharambe 2024-08-06 19:24:52 -07:00
parent ade574a0ef
commit 9e1ca4eeb1
4 changed files with 35 additions and 39 deletions

View file

@ -6,7 +6,6 @@
import argparse
import os
import shlex
import textwrap
import pkg_resources
@ -56,6 +55,7 @@ class DistributionInstall(Subcommand):
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
@ -86,12 +86,12 @@ class DistributionInstall(Subcommand):
config_file = distrib_dir / "config.yaml"
with open(config_file, "w") as f:
c = {
"conda_env": conda_env,
"spec": dist.spec_id,
"name": args.name,
}
f.write(yaml.dump(c))
c = DistributionConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
cprint(
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
@ -105,4 +105,5 @@ class DistributionInstall(Subcommand):
conda deactivate && conda activate {conda_env}
llama distribution configure --name {args.name}
"""
))
)
)