From 23fe353e4a13b32b01ecbe3568552d8ce4de6171 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Sun, 21 Jul 2024 01:16:44 -0700
Subject: [PATCH] cli -- llama inference configure

---
 toolchain/cli/download.py            | 27 +++++-----
 toolchain/cli/inference/configure.py | 80 ++++++++++++++++++++++++++++
 toolchain/cli/inference/inference.py |  2 +
 3 files changed, 94 insertions(+), 15 deletions(-)
 create mode 100644 toolchain/cli/inference/configure.py

diff --git a/toolchain/cli/download.py b/toolchain/cli/download.py
index a905cca55..10889992e 100644
--- a/toolchain/cli/download.py
+++ b/toolchain/cli/download.py
@@ -9,7 +9,7 @@
 from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 from toolchain.cli.subcommand import Subcommand
 
-DEFAULT_OUTPUT_DIR = "/tmp/llama_toolchain_cache/"
+DEFAULT_CHECKPOINT_DIR = f"{os.path.expanduser('~')}/.llama/checkpoints/"
 
 
 class Download(Subcommand):
@@ -43,13 +43,6 @@ class Download(Subcommand):
             type=str,
             help="Name of the repository on Hugging Face Hub eg. llhf/Meta-Llama-3.1-70B-Instruct",
         )
-        self.parser.add_argument(
-            "--output-dir",
-            type=Path,
-            required=False,
-            default=None,
-            help=f"Directory in which to save the model. Defaults to `{DEFAULT_OUTPUT_DIR}`.",
-        )
         self.parser.add_argument(
             "--hf-token",
             type=str,
@@ -57,18 +50,21 @@
             default=os.getenv("HF_TOKEN", None),
             help="Hugging Face API token. Needed for gated models like Llama2. Will also try to read environment variable `HF_TOKEN` as default.",
         )
+        self.parser.add_argument(
+            "--ignore-patterns",
+            type=str,
+            required=False,
+            default="*.safetensors",
+            help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring "
+            "safetensors files to avoid downloading duplicate weights.",
+        )
 
     def _run_download_cmd(self, args: argparse.Namespace):
         model_name = args.repo_id.split("/")[-1]
-        os.makedirs(DEFAULT_OUTPUT_DIR, exist_ok=True)
-        output_dir = args.output_dir
-        model_name = args.repo_id.split("/")[-1]
-        if output_dir is None:
-            output_dir = Path(DEFAULT_OUTPUT_DIR) / model_name
-        else:
-            output_dir = Path(output_dir) / model_name
+        os.makedirs(DEFAULT_CHECKPOINT_DIR, exist_ok=True)
+        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
         try:
             true_output_dir = snapshot_download(
                 args.repo_id,
@@ -76,6 +72,7 @@
                 # "auto" will download to cache_dir and symlink files to local_dir
                 # avoiding unnecessary duplicate copies
                 local_dir_use_symlinks="auto",
+                ignore_patterns=args.ignore_patterns,
                 token=args.hf_token,
             )
         except GatedRepoError:
diff --git a/toolchain/cli/inference/configure.py b/toolchain/cli/inference/configure.py
new file mode 100644
index 000000000..9c728f6d8
--- /dev/null
+++ b/toolchain/cli/inference/configure.py
@@ -0,0 +1,80 @@
+import argparse
+import os
+import textwrap
+
+from pathlib import Path
+
+from toolchain.cli.subcommand import Subcommand
+
+
+CONFIGS_BASE_DIR = f"{os.path.expanduser('~')}/.llama/configs/"
+
+class InferenceConfigure(Subcommand):
+    """Llama cli for configuring llama toolchain configs"""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "configure",
+            prog="llama inference configure",
+            description="Configure llama toolchain inference configs",
+            epilog=textwrap.dedent(
+                """
+                Example:
+                    llama inference configure
+                """
+            ),
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_inference_configure_cmd)
+
+    def _add_arguments(self):
+        pass
+
+    def read_user_inputs(self):
+        checkpoint_dir = input("Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): ")
+        model_parallel_size = input("Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): ")
+
+        return checkpoint_dir, model_parallel_size
+
+    def write_output_yaml(
+        self,
+        checkpoint_dir,
+        model_parallel_size,
+        yaml_output_path
+    ):
+        yaml_content = textwrap.dedent(f"""
+            model_inference_config:
+                impl_type: "inline"
+                inline_config:
+                    checkpoint_type: "pytorch"
+                    checkpoint_dir: {checkpoint_dir}/
+                    tokenizer_path: {checkpoint_dir}/tokenizer.model
+                    model_parallel_size: {model_parallel_size}
+                    max_seq_len: 2048
+                    max_batch_size: 1
+        """)
+        with open(yaml_output_path, 'w') as yaml_file:
+            yaml_file.write(yaml_content.strip())
+
+        print(f"YAML configuration has been written to {yaml_output_path}")
+
+    def _run_inference_configure_cmd(self, args: argparse.Namespace) -> None:
+        checkpoint_dir, model_parallel_size = self.read_user_inputs()
+        checkpoint_dir = os.path.expanduser(checkpoint_dir)
+
+        if not (
+            checkpoint_dir.endswith("original") or
+            checkpoint_dir.endswith("original/")
+        ):
+            checkpoint_dir = os.path.join(checkpoint_dir, "original")
+
+        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
+        yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
+
+        self.write_output_yaml(
+            checkpoint_dir,
+            model_parallel_size,
+            yaml_output_path,
+        )
diff --git a/toolchain/cli/inference/inference.py b/toolchain/cli/inference/inference.py
index d8cf7b1ba..b3713b4af 100644
--- a/toolchain/cli/inference/inference.py
+++ b/toolchain/cli/inference/inference.py
@@ -1,6 +1,7 @@
 import argparse
 import textwrap
 
+from toolchain.cli.inference.configure import InferenceConfigure
 from toolchain.cli.inference.start import InferenceStart
 from toolchain.cli.subcommand import Subcommand
 
@@ -26,3 +27,4 @@ class InferenceParser(Subcommand):
 
         # Add sub-commandsa
         InferenceStart.create(subparsers)
+        InferenceConfigure.create(subparsers)
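
For reference (not part of the patch): a sketch of what the new `llama inference configure` subcommand writes, assuming a user answers the two prompts with `~/.llama/checkpoints/Meta-Llama-3-8B` and a model parallel size of `1` (both illustrative values, not taken from the patch). `_run_inference_configure_cmd` expands `~`, appends `original/` to the checkpoint path, creates `~/.llama/configs/`, and `write_output_yaml` then emits roughly the following `~/.llama/configs/inference.yaml` (with `<user>` standing in for the expanded home directory):

    model_inference_config:
        impl_type: "inline"
        inline_config:
            checkpoint_type: "pytorch"
            checkpoint_dir: /home/<user>/.llama/checkpoints/Meta-Llama-3-8B/original/
            tokenizer_path: /home/<user>/.llama/checkpoints/Meta-Llama-3-8B/original/tokenizer.model
            model_parallel_size: 1
            max_seq_len: 2048
            max_batch_size: 1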