From fddaf5c929e2b55f3615dcb7cc9a36248fb30655 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 7 Aug 2024 15:27:00 -0700
Subject: [PATCH] Refactor download functionality out of the Command so can be reused

---
 llama_toolchain/cli/download.py | 163 ++++++++++++++++++--------------
 1 file changed, 93 insertions(+), 70 deletions(-)

diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py
index b8ade9b14..401bc633c 100644
--- a/llama_toolchain/cli/download.py
+++ b/llama_toolchain/cli/download.py
@@ -9,6 +9,7 @@ import asyncio
 import os
 import shutil
 import time
+from functools import partial
 from pathlib import Path
 
 import httpx
@@ -30,7 +31,7 @@ class Download(Subcommand):
             formatter_class=argparse.RawTextHelpFormatter,
         )
         self._add_arguments()
-        self.parser.set_defaults(func=self._run_download_cmd)
+        self.parser.set_defaults(func=partial(run_download_cmd, parser=self.parser))
 
     def _add_arguments(self):
         from llama_models.sku_list import all_registered_models
@@ -70,79 +71,101 @@ safetensors files to avoid downloading duplicate weights.
 """,
         )
 
-    def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
-        from huggingface_hub import snapshot_download
-        from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 
-        from llama_toolchain.common.model_utils import model_local_dir
+def _hf_download(
+    model: "Model",
+    hf_token: str,
+    ignore_patterns: str,
+    parser: argparse.ArgumentParser,
+):
+    """Download ``model``'s Hugging Face repository into its local model directory.
+
+    Raises ``ValueError`` when the model has no ``huggingface_repo``; download
+    failures are reported through ``parser.error``.
+    """
+    from huggingface_hub import snapshot_download
+    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 
-        repo_id = model.huggingface_repo
-        if repo_id is None:
-            raise ValueError(f"No repo id found for model {model.descriptor()}")
+    from llama_toolchain.common.model_utils import model_local_dir
 
-        output_dir = model_local_dir(model)
-        os.makedirs(output_dir, exist_ok=True)
-        try:
-            true_output_dir = snapshot_download(
-                repo_id,
-                local_dir=output_dir,
-                ignore_patterns=ignore_patterns,
-                token=hf_token,
-                library_name="llama-toolchain",
+    repo_id = model.huggingface_repo
+    if repo_id is None:
+        raise ValueError(f"No repo id found for model {model.descriptor()}")
+
+    output_dir = model_local_dir(model)
+    os.makedirs(output_dir, exist_ok=True)
+    try:
+        true_output_dir = snapshot_download(
+            repo_id,
+            local_dir=output_dir,
+            ignore_patterns=ignore_patterns,
+            token=hf_token,
+            library_name="llama-toolchain",
+        )
+    except GatedRepoError:
+        parser.error(
+            "It looks like you are trying to access a gated repository. Please ensure you "
+            "have access to the repository and have provided the proper Hugging Face API token "
+            "using the option `--hf-token` or by running `huggingface-cli login`. "
+            "You can find your token by visiting https://huggingface.co/settings/tokens"
+        )
+    except RepositoryNotFoundError:
+        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
+    except Exception as e:
+        parser.error(e)
+
+    print(f"\nSuccessfully downloaded model to {true_output_dir}")
+
+
+def _meta_download(model: "Model", meta_url: str):
+    """Download every file listed for ``model`` using the signed ``meta_url``.
+
+    The ``*`` placeholder in ``meta_url`` is replaced by ``<folder>/<file>``
+    for each file, and the files are fetched one after another.
+    """
+    from llama_models.sku_list import llama_meta_net_info
+
+    from llama_toolchain.common.model_utils import model_local_dir
+
+    output_dir = Path(model_local_dir(model))
+    os.makedirs(output_dir, exist_ok=True)
+
+    info = llama_meta_net_info(model)
+
+    # I believe we can use some concurrency here if needed but not sure it is worth it
+    for f in info.files:
+        output_file = str(output_dir / f)
+        url = meta_url.replace("*", f"{info.folder}/{f}")
+        total_size = info.pth_size if "consolidated" in f else 0
+        cprint(f"Downloading `{f}`...", "white")
+        downloader = ResumableDownloader(url, output_file, total_size)
+        asyncio.run(downloader.download())
+
+
+def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
+    """Resolve ``args.model_id`` and download it from ``args.source``.
+
+    ``parser`` is used to report errors; when the source is not
+    ``"huggingface"`` a signed Meta URL is taken from ``args.meta_url`` or
+    prompted for interactively.
+    """
+    from llama_models.sku_list import resolve_model
+
+    model = resolve_model(args.model_id)
+    if model is None:
+        parser.error(f"Model {args.model_id} not found")
+        return
+
+    if args.source == "huggingface":
+        _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+    else:
+        meta_url = args.meta_url
+        if not meta_url:
+            meta_url = input(
+                "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
             )
-        except GatedRepoError:
-            self.parser.error(
-                "It looks like you are trying to access a gated repository. Please ensure you "
-                "have access to the repository and have provided the proper Hugging Face API token "
-                "using the option `--hf-token` or by running `huggingface-cli login`."
-                "You can find your token by visiting https://huggingface.co/settings/tokens"
-            )
-        except RepositoryNotFoundError:
-            self.parser.error(
-                f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
-            )
-        except Exception as e:
-            self.parser.error(e)
-
-        print(f"Successfully downloaded model to {true_output_dir}")
-
-    def _meta_download(self, model: "Model", meta_url: str):
-        from llama_models.sku_list import llama_meta_net_info
-
-        from llama_toolchain.common.model_utils import model_local_dir
-
-        output_dir = Path(model_local_dir(model))
-        os.makedirs(output_dir, exist_ok=True)
-
-        info = llama_meta_net_info(model)
-
-        # I believe we can use some concurrency here if needed but not sure it is worth it
-        for f in info.files:
-            output_file = str(output_dir / f)
-            url = meta_url.replace("*", f"{info.folder}/{f}")
-            total_size = info.pth_size if "consolidated" in f else 0
-            cprint(f"Downloading `{f}`...", "white")
-            downloader = ResumableDownloader(url, output_file, total_size)
-            asyncio.run(downloader.download())
-
-    def _run_download_cmd(self, args: argparse.Namespace):
-        from llama_models.sku_list import resolve_model
-
-        model = resolve_model(args.model_id)
-        if model is None:
-            self.parser.error(f"Model {args.model_id} not found")
-            return
-
-        if args.source == "huggingface":
-            self._hf_download(model, args.hf_token, args.ignore_patterns)
-        else:
-            meta_url = args.meta_url
-            if not meta_url:
-                meta_url = input(
-                    "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
-                )
-            assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
-            self._meta_download(model, meta_url)
+        assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
+        _meta_download(model, meta_url)
 
 
 class ResumableDownloader: