mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Refactor download functionality out of the Command so can be reused
This commit is contained in:
parent
68654460f8
commit
fddaf5c929
1 changed files with 77 additions and 70 deletions
|
@ -9,6 +9,7 @@ import asyncio
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -30,7 +31,7 @@ class Download(Subcommand):
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
self._add_arguments()
|
self._add_arguments()
|
||||||
self.parser.set_defaults(func=self._run_download_cmd)
|
self.parser.set_defaults(func=partial(run_download_cmd, parser=self.parser))
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
from llama_models.sku_list import all_registered_models
|
from llama_models.sku_list import all_registered_models
|
||||||
|
@ -70,79 +71,85 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
|
||||||
|
|
||||||
from llama_toolchain.common.model_utils import model_local_dir
|
def _hf_download(
|
||||||
|
model: "Model",
|
||||||
|
hf_token: str,
|
||||||
|
ignore_patterns: str,
|
||||||
|
parser: argparse.ArgumentParser,
|
||||||
|
):
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||||
|
|
||||||
repo_id = model.huggingface_repo
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
if repo_id is None:
|
|
||||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
|
||||||
|
|
||||||
output_dir = model_local_dir(model)
|
repo_id = model.huggingface_repo
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
if repo_id is None:
|
||||||
try:
|
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||||
true_output_dir = snapshot_download(
|
|
||||||
repo_id,
|
output_dir = model_local_dir(model)
|
||||||
local_dir=output_dir,
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
ignore_patterns=ignore_patterns,
|
try:
|
||||||
token=hf_token,
|
true_output_dir = snapshot_download(
|
||||||
library_name="llama-toolchain",
|
repo_id,
|
||||||
|
local_dir=output_dir,
|
||||||
|
ignore_patterns=ignore_patterns,
|
||||||
|
token=hf_token,
|
||||||
|
library_name="llama-toolchain",
|
||||||
|
)
|
||||||
|
except GatedRepoError:
|
||||||
|
parser.error(
|
||||||
|
"It looks like you are trying to access a gated repository. Please ensure you "
|
||||||
|
"have access to the repository and have provided the proper Hugging Face API token "
|
||||||
|
"using the option `--hf-token` or by running `huggingface-cli login`."
|
||||||
|
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
||||||
|
)
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
parser.error(f"Repository '{args.repo_id}' not found on the Hugging Face Hub.")
|
||||||
|
except Exception as e:
|
||||||
|
parser.error(e)
|
||||||
|
|
||||||
|
print(f"\nSuccessfully downloaded model to {true_output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def _meta_download(model: "Model", meta_url: str):
|
||||||
|
from llama_models.sku_list import llama_meta_net_info
|
||||||
|
|
||||||
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
|
|
||||||
|
output_dir = Path(model_local_dir(model))
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
info = llama_meta_net_info(model)
|
||||||
|
|
||||||
|
# I believe we can use some concurrency here if needed but not sure it is worth it
|
||||||
|
for f in info.files:
|
||||||
|
output_file = str(output_dir / f)
|
||||||
|
url = meta_url.replace("*", f"{info.folder}/{f}")
|
||||||
|
total_size = info.pth_size if "consolidated" in f else 0
|
||||||
|
cprint(f"Downloading `{f}`...", "white")
|
||||||
|
downloader = ResumableDownloader(url, output_file, total_size)
|
||||||
|
asyncio.run(downloader.download())
|
||||||
|
|
||||||
|
|
||||||
|
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||||
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
|
model = resolve_model(args.model_id)
|
||||||
|
if model is None:
|
||||||
|
parser.error(f"Model {args.model_id} not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.source == "huggingface":
|
||||||
|
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
|
||||||
|
else:
|
||||||
|
meta_url = args.meta_url
|
||||||
|
if not meta_url:
|
||||||
|
meta_url = input(
|
||||||
|
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||||
)
|
)
|
||||||
except GatedRepoError:
|
assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
|
||||||
self.parser.error(
|
_meta_download(model, meta_url)
|
||||||
"It looks like you are trying to access a gated repository. Please ensure you "
|
|
||||||
"have access to the repository and have provided the proper Hugging Face API token "
|
|
||||||
"using the option `--hf-token` or by running `huggingface-cli login`."
|
|
||||||
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
|
||||||
)
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
self.parser.error(
|
|
||||||
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.parser.error(e)
|
|
||||||
|
|
||||||
print(f"Successfully downloaded model to {true_output_dir}")
|
|
||||||
|
|
||||||
def _meta_download(self, model: "Model", meta_url: str):
|
|
||||||
from llama_models.sku_list import llama_meta_net_info
|
|
||||||
|
|
||||||
from llama_toolchain.common.model_utils import model_local_dir
|
|
||||||
|
|
||||||
output_dir = Path(model_local_dir(model))
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
info = llama_meta_net_info(model)
|
|
||||||
|
|
||||||
# I believe we can use some concurrency here if needed but not sure it is worth it
|
|
||||||
for f in info.files:
|
|
||||||
output_file = str(output_dir / f)
|
|
||||||
url = meta_url.replace("*", f"{info.folder}/{f}")
|
|
||||||
total_size = info.pth_size if "consolidated" in f else 0
|
|
||||||
cprint(f"Downloading `{f}`...", "white")
|
|
||||||
downloader = ResumableDownloader(url, output_file, total_size)
|
|
||||||
asyncio.run(downloader.download())
|
|
||||||
|
|
||||||
def _run_download_cmd(self, args: argparse.Namespace):
|
|
||||||
from llama_models.sku_list import resolve_model
|
|
||||||
|
|
||||||
model = resolve_model(args.model_id)
|
|
||||||
if model is None:
|
|
||||||
self.parser.error(f"Model {args.model_id} not found")
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.source == "huggingface":
|
|
||||||
self._hf_download(model, args.hf_token, args.ignore_patterns)
|
|
||||||
else:
|
|
||||||
meta_url = args.meta_url
|
|
||||||
if not meta_url:
|
|
||||||
meta_url = input(
|
|
||||||
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
|
||||||
)
|
|
||||||
assert meta_url is not None and "llama3-1.llamameta.net" in meta_url
|
|
||||||
self._meta_download(model, meta_url)
|
|
||||||
|
|
||||||
|
|
||||||
class ResumableDownloader:
|
class ResumableDownloader:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue