From 040c30ee546cfbe5f6bb7aca709bd35215f5648a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 29 Jul 2024 07:41:07 -0700 Subject: [PATCH] added resumable downloader for downloading models --- llama_toolchain/cli/download.py | 192 +++++++++++++++++++++++++++----- 1 file changed, 167 insertions(+), 25 deletions(-) diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index ed41ac505..ee774c67d 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -5,13 +5,27 @@ # the root directory of this source tree. import argparse +import asyncio import os -import textwrap +import shutil +import time from pathlib import Path +from typing import Optional + +import httpx from huggingface_hub import snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError +from llama_models.llama3_1.api.datatypes import ( + CheckpointQuantizationFormat, + ModelDefinition, +) +from llama_models.llama3_1.api.sku_list import ( + llama3_1_model_list, + llama_meta_folder_path, +) + from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.utils import DEFAULT_DUMP_DIR @@ -27,55 +41,61 @@ class Download(Subcommand): self.parser = subparsers.add_parser( "download", prog="llama download", - description="Download a model from the Hugging Face Hub", - epilog=textwrap.dedent( - """\ - # Here are some examples on how to use this command: - - llama download --repo-id meta-llama/Llama-2-7b-hf --hf-token - llama download --repo-id meta-llama/Llama-2-7b-hf --output-dir /data/my_custom_dir --hf-token - HF_TOKEN= llama download --repo-id meta-llama/Llama-2-7b-hf - - The output directory will be used to load models and tokenizers for inference. - """ - ), + description="Download a model from llama.meta.comf or HuggingFace hub", formatter_class=argparse.RawTextHelpFormatter, ) self._add_arguments() self.parser.set_defaults(func=self._run_download_cmd) def _add_arguments(self): + models = llama3_1_model_list() self.parser.add_argument( - "repo_id", - type=str, - help="Name of the repository on Hugging Face Hub eg. meta-llama/Meta-Llama-3.1-70B-Instruct", + "--source", + choices=["meta", "huggingface"], + required=True, + ) + self.parser.add_argument( + "--model-id", + choices=[x.sku.value for x in models], + required=True, ) self.parser.add_argument( "--hf-token", type=str, required=False, default=None, - help="Hugging Face API token. Needed for gated models like Llama2. Will also try to read environment variable `HF_TOKEN` as default.", + help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.", + ) + self.parser.add_argument( + "--meta-url", + type=str, + required=False, + help="For source=meta, URL obtained from llama.meta.com after accepting license terms", ) self.parser.add_argument( "--ignore-patterns", type=str, required=False, default="*.safetensors", - help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring " - "safetensors files to avoid downloading duplicate weights.", + help=""" +For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring +safetensors files to avoid downloading duplicate weights. +""", ) - def _run_download_cmd(self, args: argparse.Namespace): - model_name = args.repo_id.split("/")[-1] - output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name + def _hf_download(self, model: ModelDefinition, hf_token: str, ignore_patterns: str): + repo_id = model.huggingface_id + if repo_id is None: + raise ValueError(f"No repo id found for model {model.sku.value}") + + output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.sku.value os.makedirs(output_dir, exist_ok=True) try: true_output_dir = snapshot_download( - args.repo_id, + repo_id, local_dir=output_dir, - ignore_patterns=args.ignore_patterns, - token=args.hf_token, + ignore_patterns=ignore_patterns, + token=hf_token, library_name="llama-toolchain", ) except GatedRepoError: @@ -93,3 +113,125 @@ class Download(Subcommand): self.parser.error(e) print(f"Successfully downloaded model to {true_output_dir}") + + def _meta_download(self, model: ModelDefinition, meta_url: str): + output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.sku.value + os.makedirs(output_dir, exist_ok=True) + + gpus = model.hardware_requirements.gpu_count + files = [ + "tokenizer.model", + ] + files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)]) + if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)]) + + folder_path = llama_meta_folder_path(model) + + # I believe we can use some concurrency here if needed but not sure it is worth it + for f in files: + output_file = str(output_dir / f) + url = meta_url.replace("*", f"{folder_path}/{f}") + downloader = ResumableDownloader(url, output_file) + asyncio.run(downloader.download()) + + def _run_download_cmd(self, args: argparse.Namespace): + by_id = {model.sku.value: model for model in llama3_1_model_list()} + assert args.model_id in by_id, f"Unexpected model id {args.model_id}" + + model = by_id[args.model_id] + if args.source == "huggingface": + self._hf_download(model, args.hf_token, args.ignore_patterns) + else: + if not args.meta_url: + self.parser.error( + "Please provide a meta url to download the model from llama.meta.com" + ) + self._meta_download(model, args.meta_url) + + +class ResumableDownloader: + def __init__(self, url: str, output_file: str, buffer_size: int = 32 * 1024): + self.url = url + self.output_file = output_file + self.buffer_size = buffer_size + self.total_size: Optional[int] = None + self.downloaded_size = 0 + self.start_size = 0 + self.start_time = 0 + + async def get_file_info(self, client: httpx.AsyncClient) -> None: + response = await client.head(self.url, follow_redirects=True) + response.raise_for_status() + self.url = str(response.url) # Update URL in case of redirects + self.total_size = int(response.headers.get("Content-Length", 0)) + if self.total_size == 0: + raise ValueError( + "Unable to determine file size. The server might not support range requests." + ) + + async def download(self) -> None: + self.start_time = time.time() + async with httpx.AsyncClient() as client: + await self.get_file_info(client) + + if os.path.exists(self.output_file): + self.downloaded_size = os.path.getsize(self.output_file) + self.start_size = self.downloaded_size + if self.downloaded_size >= self.total_size: + print(f"Already downloaded `{self.output_file}`, skipping...") + return + + additional_size = self.total_size - self.downloaded_size + if not self.has_disk_space(additional_size): + print( + f"Not enough disk space to download `{self.output_file}`. " + f"Required: {(additional_size / M):.2f} MB" + ) + raise ValueError( + f"Not enough disk space to download `{self.output_file}`" + ) + + headers = {"Range": f"bytes={self.downloaded_size}-"} + try: + async with client.stream("GET", self.url, headers=headers) as response: + response.raise_for_status() + with open(self.output_file, "ab") as file: + async for chunk in response.aiter_bytes(self.buffer_size): + file.write(chunk) + self.downloaded_size += len(chunk) + self.print_progress() + except httpx.HTTPError as e: + print(f"\nDownload interrupted: {e}") + print("You can resume the download by running the script again.") + except Exception as e: + print(f"\nAn error occurred: {e}") + else: + print(f"\nFinished downloading `{self.output_file}`....") + + def print_progress(self) -> None: + percent = (self.downloaded_size / self.total_size) * 100 + bar_length = 50 + filled_length = int(bar_length * self.downloaded_size // self.total_size) + bar = "█" * filled_length + "-" * (bar_length - filled_length) + + elapsed_time = time.time() - self.start_time + M = 1024 * 1024 # noqa + + speed = ( + (self.downloaded_size - self.start_size) / (elapsed_time * M) + if elapsed_time > 0 + else 0 + ) + print( + f"\rProgress: |{bar}| {percent:.2f}% " + f"({self.downloaded_size // M}/{self.total_size // M} MB) " + f"Speed: {speed:.2f} MiB/s", + end="", + flush=True, + ) + + def has_disk_space(self, file_size: int) -> bool: + dir_path = os.path.dirname(os.path.abspath(self.output_file)) + free_space = shutil.disk_usage(dir_path).free + return free_space > file_size