mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
More progress towards llama distribution install
This commit is contained in:
parent
5a583cf16e
commit
dac2b5a1ed
11 changed files with 298 additions and 75 deletions
|
@ -6,6 +6,8 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -13,10 +15,12 @@ import pkg_resources
|
|||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.distribution.registry import all_registered_distributions
|
||||
from llama_toolchain.utils import DEFAULT_DUMP_DIR
|
||||
from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
|
||||
|
||||
|
||||
CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
|
||||
DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
|
||||
|
||||
DISTRIBS = all_registered_distributions()
|
||||
|
||||
|
||||
class DistributionInstall(Subcommand):
|
||||
|
@ -34,59 +38,45 @@ class DistributionInstall(Subcommand):
|
|||
self.parser.set_defaults(func=self._run_distribution_install_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
distribs = all_registered_distributions()
|
||||
self.parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
help="Mame of the distribution to install",
|
||||
default="local-source",
|
||||
choices=[d.name for d in distribs],
|
||||
help="Name of the distribution to install -- (try local-ollama)",
|
||||
required=True,
|
||||
choices=[d.name for d in DISTRIBS],
|
||||
)
|
||||
|
||||
def read_user_inputs(self):
|
||||
checkpoint_dir = input(
|
||||
"Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
|
||||
self.parser.add_argument(
|
||||
"--conda-env",
|
||||
type=str,
|
||||
help="Specify the name of the conda environment you would like to create or update",
|
||||
required=True,
|
||||
)
|
||||
model_parallel_size = input(
|
||||
"Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
|
||||
)
|
||||
assert model_parallel_size.isdigit() and int(model_parallel_size) in {
|
||||
1,
|
||||
8,
|
||||
}, "model parallel size must be 1 or 8"
|
||||
|
||||
return checkpoint_dir, model_parallel_size
|
||||
|
||||
def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
|
||||
default_conf_path = pkg_resources.resource_filename(
|
||||
"llama_toolchain", "data/default_distribution_config.yaml"
|
||||
)
|
||||
with open(default_conf_path, "r") as f:
|
||||
yaml_content = f.read()
|
||||
|
||||
yaml_content = yaml_content.format(
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
model_parallel_size=model_parallel_size,
|
||||
)
|
||||
|
||||
with open(yaml_output_path, "w") as yaml_file:
|
||||
yaml_file.write(yaml_content.strip())
|
||||
|
||||
print(f"YAML configuration has been written to {yaml_output_path}")
|
||||
|
||||
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
||||
checkpoint_dir, model_parallel_size = self.read_user_inputs()
|
||||
checkpoint_dir = os.path.expanduser(checkpoint_dir)
|
||||
|
||||
assert (
|
||||
Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
|
||||
), f"{checkpoint_dir} does not exist or it not a directory"
|
||||
|
||||
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
|
||||
yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
|
||||
|
||||
self.write_output_yaml(
|
||||
checkpoint_dir,
|
||||
model_parallel_size,
|
||||
yaml_output_path,
|
||||
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_toolchain",
|
||||
"distribution/install_distribution.sh",
|
||||
)
|
||||
|
||||
dist = None
|
||||
for d in DISTRIBS:
|
||||
if d.name == args.name:
|
||||
dist = d
|
||||
break
|
||||
|
||||
if dist is None:
|
||||
self.parser.error(f"Could not find distribution {args.name}")
|
||||
return
|
||||
|
||||
os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
|
||||
run_shell_script(script, args.conda_env, " ".join(dist.pip_packages))
|
||||
with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
|
||||
f.write(f"{args.conda_env}\n")
|
||||
|
||||
|
||||
def run_shell_script(script_path, *args):
|
||||
command_string = f"{script_path} {' '.join(shlex.quote(str(arg)) for arg in args)}"
|
||||
command_list = shlex.split(command_string)
|
||||
print(f"Running command: {command_list}")
|
||||
subprocess.run(command_list, check=True, text=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue