mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Refactor download functionality out of the Command so can be reused
This commit is contained in:
parent
68654460f8
commit
fddaf5c929
1 changed files with 77 additions and 70 deletions
|
@ -9,6 +9,7 @@ import asyncio
|
|||
import os
|
||||
import shutil
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
@ -30,7 +31,7 @@ class Download(Subcommand):
|
|||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_download_cmd)
|
||||
self.parser.set_defaults(func=partial(run_download_cmd, parser=self.parser))
|
||||
|
||||
def _add_arguments(self):
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
@ -70,79 +71,85 @@ safetensors files to avoid downloading duplicate weights.
|
|||
""",
|
||||
)
|
||||
|
||||
def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
def _hf_download(
|
||||
model: "Model",
|
||||
hf_token: str,
|
||||
ignore_patterns: str,
|
||||
parser: argparse.ArgumentParser,
|
||||
):
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
output_dir = model_local_dir(model)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
true_output_dir = snapshot_download(
|
||||
repo_id,
|
||||
local_dir=output_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
token=hf_token,
|
||||
library_name="llama-toolchain",
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
|
||||
output_dir = model_local_dir(model)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
true_output_dir = snapshot_download(
|
||||
repo_id,
|
||||
local_dir=output_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
token=hf_token,
|
||||
library_name="llama-toolchain",
|
||||
)
|
||||
except GatedRepoError:
|
||||
parser.error(
|
||||
"It looks like you are trying to access a gated repository. Please ensure you "
|
||||
"have access to the repository and have provided the proper Hugging Face API token "
|
||||
"using the option `--hf-token` or by running `huggingface-cli login`."
|
||||
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
parser.error(f"Repository '{args.repo_id}' not found on the Hugging Face Hub.")
|
||||
except Exception as e:
|
||||
parser.error(e)
|
||||
|
||||
print(f"\nSuccessfully downloaded model to {true_output_dir}")
|
||||
|
||||
|
||||
def _meta_download(model: "Model", meta_url: str):
|
||||
from llama_models.sku_list import llama_meta_net_info
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(model_local_dir(model))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
info = llama_meta_net_info(model)
|
||||
|
||||
# I believe we can use some concurrency here if needed but not sure it is worth it
|
||||
for f in info.files:
|
||||
output_file = str(output_dir / f)
|
||||
url = meta_url.replace("*", f"{info.folder}/{f}")
|
||||
total_size = info.pth_size if "consolidated" in f else 0
|
||||
cprint(f"Downloading `{f}`...", "white")
|
||||
downloader = ResumableDownloader(url, output_file, total_size)
|
||||
asyncio.run(downloader.download())
|
||||
|
||||
|
||||
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
model = resolve_model(args.model_id)
|
||||
if model is None:
|
||||
parser.error(f"Model {args.model_id} not found")
|
||||
return
|
||||
|
||||
if args.source == "huggingface":
|
||||
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
|
||||
else:
|
||||
meta_url = args.meta_url
|
||||
if not meta_url:
|
||||
meta_url = input(
|
||||
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
)
|
||||
except GatedRepoError:
|
||||
self.parser.error(
|
||||
"It looks like you are trying to access a gated repository. Please ensure you "
|
||||
"have access to the repository and have provided the proper Hugging Face API token "
|
||||
"using the option `--hf-token` or by running `huggingface-cli login`."
|
||||
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
self.parser.error(
|
||||
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
|
||||
)
|
||||
except Exception as e:
|
||||
self.parser.error(e)
|
||||
|
||||
print(f"Successfully downloaded model to {true_output_dir}")
|
||||
|
||||
def _meta_download(self, model: "Model", meta_url: str):
|
||||
from llama_models.sku_list import llama_meta_net_info
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(model_local_dir(model))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
info = llama_meta_net_info(model)
|
||||
|
||||
# I believe we can use some concurrency here if needed but not sure it is worth it
|
||||
for f in info.files:
|
||||
output_file = str(output_dir / f)
|
||||
url = meta_url.replace("*", f"{info.folder}/{f}")
|
||||
total_size = info.pth_size if "consolidated" in f else 0
|
||||
cprint(f"Downloading `{f}`...", "white")
|
||||
downloader = ResumableDownloader(url, output_file, total_size)
|
||||
asyncio.run(downloader.download())
|
||||
|
||||
def _run_download_cmd(self, args: argparse.Namespace):
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
model = resolve_model(args.model_id)
|
||||
if model is None:
|
||||
self.parser.error(f"Model {args.model_id} not found")
|
||||
return
|
||||
|
||||
if args.source == "huggingface":
|
||||
self._hf_download(model, args.hf_token, args.ignore_patterns)
|
||||
else:
|
||||
meta_url = args.meta_url
|
||||
if not meta_url:
|
||||
meta_url = input(
|
||||
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
)
|
||||
assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
|
||||
self._meta_download(model, meta_url)
|
||||
assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
|
||||
_meta_download(model, meta_url)
|
||||
|
||||
|
||||
class ResumableDownloader:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue