Refactor download functionality out of the Command so it can be reused

This commit is contained in:
Ashwin Bharambe 2024-08-07 15:27:00 -07:00
parent 68654460f8
commit fddaf5c929

View file

@ -9,6 +9,7 @@ import asyncio
import os import os
import shutil import shutil
import time import time
from functools import partial
from pathlib import Path from pathlib import Path
import httpx import httpx
@ -30,7 +31,7 @@ class Download(Subcommand):
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
self._add_arguments() self._add_arguments()
self.parser.set_defaults(func=self._run_download_cmd) self.parser.set_defaults(func=partial(run_download_cmd, parser=self.parser))
def _add_arguments(self): def _add_arguments(self):
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
@ -70,7 +71,13 @@ safetensors files to avoid downloading duplicate weights.
""", """,
) )
def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
def _hf_download(
model: "Model",
hf_token: str,
ignore_patterns: str,
parser: argparse.ArgumentParser,
):
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
@ -91,22 +98,21 @@ safetensors files to avoid downloading duplicate weights.
library_name="llama-toolchain", library_name="llama-toolchain",
) )
except GatedRepoError: except GatedRepoError:
self.parser.error( parser.error(
"It looks like you are trying to access a gated repository. Please ensure you " "It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token " "have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`." "using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens" "You can find your token by visiting https://huggingface.co/settings/tokens"
) )
except RepositoryNotFoundError: except RepositoryNotFoundError:
self.parser.error( parser.error(f"Repository '{args.repo_id}' not found on the Hugging Face Hub.")
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
)
except Exception as e: except Exception as e:
self.parser.error(e) parser.error(e)
print(f"Successfully downloaded model to {true_output_dir}") print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(self, model: "Model", meta_url: str):
def _meta_download(model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info from llama_models.sku_list import llama_meta_net_info
from llama_toolchain.common.model_utils import model_local_dir from llama_toolchain.common.model_utils import model_local_dir
@ -125,16 +131,17 @@ safetensors files to avoid downloading duplicate weights.
downloader = ResumableDownloader(url, output_file, total_size) downloader = ResumableDownloader(url, output_file, total_size)
asyncio.run(downloader.download()) asyncio.run(downloader.download())
def _run_download_cmd(self, args: argparse.Namespace):
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
model = resolve_model(args.model_id) model = resolve_model(args.model_id)
if model is None: if model is None:
self.parser.error(f"Model {args.model_id} not found") parser.error(f"Model {args.model_id} not found")
return return
if args.source == "huggingface": if args.source == "huggingface":
self._hf_download(model, args.hf_token, args.ignore_patterns) _hf_download(model, args.hf_token, args.ignore_patterns, parser)
else: else:
meta_url = args.meta_url meta_url = args.meta_url
if not meta_url: if not meta_url:
@ -142,7 +149,7 @@ safetensors files to avoid downloading duplicate weights.
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): " "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
) )
assert meta_url is not None and "llama3-1.llamameta.net" in meta_url assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
self._meta_download(model, meta_url) _meta_download(model, meta_url)
class ResumableDownloader: class ResumableDownloader: