Refactor download functionality out of the Command so can be reused

This commit is contained in:
Ashwin Bharambe 2024-08-07 15:27:00 -07:00
parent 68654460f8
commit fddaf5c929

View file

@ -9,6 +9,7 @@ import asyncio
import os
import shutil
import time
from functools import partial
from pathlib import Path
import httpx
@ -30,7 +31,7 @@ class Download(Subcommand):
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_download_cmd)
self.parser.set_defaults(func=partial(run_download_cmd, parser=self.parser))
def _add_arguments(self):
from llama_models.sku_list import all_registered_models
@ -70,7 +71,13 @@ safetensors files to avoid downloading duplicate weights.
""",
)
def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
def _hf_download(
model: "Model",
hf_token: str,
ignore_patterns: str,
parser: argparse.ArgumentParser,
):
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
@ -91,22 +98,21 @@ safetensors files to avoid downloading duplicate weights.
library_name="llama-toolchain",
)
except GatedRepoError:
self.parser.error(
parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
self.parser.error(
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
)
parser.error(f"Repository '{args.repo_id}' not found on the Hugging Face Hub.")
except Exception as e:
self.parser.error(e)
parser.error(e)
print(f"Successfully downloaded model to {true_output_dir}")
print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(self, model: "Model", meta_url: str):
def _meta_download(model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info
from llama_toolchain.common.model_utils import model_local_dir
@ -125,16 +131,17 @@ safetensors files to avoid downloading duplicate weights.
downloader = ResumableDownloader(url, output_file, total_size)
asyncio.run(downloader.download())
def _run_download_cmd(self, args: argparse.Namespace):
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_models.sku_list import resolve_model
model = resolve_model(args.model_id)
if model is None:
self.parser.error(f"Model {args.model_id} not found")
parser.error(f"Model {args.model_id} not found")
return
if args.source == "huggingface":
self._hf_download(model, args.hf_token, args.ignore_patterns)
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url
if not meta_url:
@ -142,7 +149,7 @@ safetensors files to avoid downloading duplicate weights.
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
self._meta_download(model, meta_url)
_meta_download(model, meta_url)
class ResumableDownloader: