Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00

Commit 040c30ee54 (parent 59574924de): added resumable downloader for downloading models

1 changed file with 167 additions and 25 deletions
@@ -5,13 +5,27 @@
 # the root directory of this source tree.
 
 import argparse
+import asyncio
 import os
 import textwrap
+import shutil
+import time
 from pathlib import Path
+from typing import Optional
+
+import httpx
 
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 
+from llama_models.llama3_1.api.datatypes import (
+    CheckpointQuantizationFormat,
+    ModelDefinition,
+)
+from llama_models.llama3_1.api.sku_list import (
+    llama3_1_model_list,
+    llama_meta_folder_path,
+)
+
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.utils import DEFAULT_DUMP_DIR
@@ -27,55 +41,61 @@ class Download(Subcommand):
         self.parser = subparsers.add_parser(
             "download",
             prog="llama download",
-            description="Download a model from the Hugging Face Hub",
-            epilog=textwrap.dedent(
-                """\
-                # Here are some examples on how to use this command:
-
-                    llama download --repo-id meta-llama/Llama-2-7b-hf --hf-token <HF_TOKEN>
-                    llama download --repo-id meta-llama/Llama-2-7b-hf --output-dir /data/my_custom_dir --hf-token <HF_TOKEN>
-                    HF_TOKEN=<HF_TOKEN> llama download --repo-id meta-llama/Llama-2-7b-hf
-
-                The output directory will be used to load models and tokenizers for inference.
-                """
-            ),
+            description="Download a model from llama.meta.com or HuggingFace hub",
             formatter_class=argparse.RawTextHelpFormatter,
         )
         self._add_arguments()
         self.parser.set_defaults(func=self._run_download_cmd)
 
     def _add_arguments(self):
+        models = llama3_1_model_list()
         self.parser.add_argument(
-            "repo_id",
-            type=str,
-            help="Name of the repository on Hugging Face Hub eg. meta-llama/Meta-Llama-3.1-70B-Instruct",
+            "--source",
+            choices=["meta", "huggingface"],
+            required=True,
         )
+        self.parser.add_argument(
+            "--model-id",
+            choices=[x.sku.value for x in models],
+            required=True,
+        )
         self.parser.add_argument(
             "--hf-token",
             type=str,
             required=False,
             default=None,
-            help="Hugging Face API token. Needed for gated models like Llama2. Will also try to read environment variable `HF_TOKEN` as default.",
+            help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
         )
+        self.parser.add_argument(
+            "--meta-url",
+            type=str,
+            required=False,
+            help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
+        )
         self.parser.add_argument(
             "--ignore-patterns",
             type=str,
             required=False,
             default="*.safetensors",
-            help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring "
-            "safetensors files to avoid downloading duplicate weights.",
+            help="""
+            For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
+            safetensors files to avoid downloading duplicate weights.
+            """,
         )
 
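The removed epilog leaves `llama download` without usage examples, and this commit does not add new ones. Purely as a hedged sketch (not part of the commit), updated examples for the new `--source`/`--model-id` flags might read as follows; the model id, token, and signed URL are all placeholders:

    # Hypothetical epilog text illustrating the new flags; every <...> value is a placeholder.
    import textwrap

    EXAMPLE_EPILOG = textwrap.dedent(
        """\
        # Here are some examples on how to use this command:

            llama download --source huggingface --model-id <MODEL_ID> --hf-token <HF_TOKEN>
            HF_TOKEN=<HF_TOKEN> llama download --source huggingface --model-id <MODEL_ID>
            llama download --source meta --model-id <MODEL_ID> --meta-url '<SIGNED_URL_WITH_*>'
        """
    )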
-    def _run_download_cmd(self, args: argparse.Namespace):
-        model_name = args.repo_id.split("/")[-1]
-        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
+    def _hf_download(self, model: ModelDefinition, hf_token: str, ignore_patterns: str):
+        repo_id = model.huggingface_id
+        if repo_id is None:
+            raise ValueError(f"No repo id found for model {model.sku.value}")
+
+        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.sku.value
         os.makedirs(output_dir, exist_ok=True)
         try:
             true_output_dir = snapshot_download(
-                args.repo_id,
+                repo_id,
                 local_dir=output_dir,
-                ignore_patterns=args.ignore_patterns,
-                token=args.hf_token,
+                ignore_patterns=ignore_patterns,
+                token=hf_token,
                 library_name="llama-toolchain",
             )
         except GatedRepoError:
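The `--hf-token` help text promises a fallback to the `HF_TOKEN` environment variable, while the code simply passes `hf_token` (possibly `None`) through to `snapshot_download`; huggingface_hub reads `HF_TOKEN` itself when `token` is `None`, so that typically works. Below is a minimal sketch of an explicit fallback; `resolve_hf_token` is a hypothetical helper, not anything in the commit:

    import argparse
    import os
    from typing import Optional

    def resolve_hf_token(args: argparse.Namespace) -> Optional[str]:
        # Prefer the explicit --hf-token flag, then fall back to the HF_TOKEN
        # environment variable, matching what the help text describes.
        return args.hf_token or os.getenv("HF_TOKEN")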
@@ -93,3 +113,125 @@ class Download(Subcommand):
             self.parser.error(e)
 
         print(f"Successfully downloaded model to {true_output_dir}")
+
+    def _meta_download(self, model: ModelDefinition, meta_url: str):
+        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.sku.value
+        os.makedirs(output_dir, exist_ok=True)
+
+        gpus = model.hardware_requirements.gpu_count
+        files = [
+            "tokenizer.model",
+        ]
+        files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
+        if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
+            files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)])
+
+        folder_path = llama_meta_folder_path(model)
+
+        # I believe we can use some concurrency here if needed but not sure it is worth it
+        for f in files:
+            output_file = str(output_dir / f)
+            url = meta_url.replace("*", f"{folder_path}/{f}")
+            downloader = ResumableDownloader(url, output_file)
+            asyncio.run(downloader.download())
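`_meta_download` assumes the signed URL from llama.meta.com contains a literal `*` that is substituted with the per-file path. A tiny sketch of that expansion with made-up values (the URL shape and folder name are purely illustrative):

    # Illustrative only: placeholder signed URL and folder name.
    meta_url = "https://example.invalid/files/*?Signature=abc"
    folder_path = "some-model-folder"
    files = ["tokenizer.model", "consolidated.00.pth"]

    urls = [meta_url.replace("*", f"{folder_path}/{f}") for f in files]
    # -> https://example.invalid/files/some-model-folder/tokenizer.model?Signature=abc, ...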
+
+    def _run_download_cmd(self, args: argparse.Namespace):
+        by_id = {model.sku.value: model for model in llama3_1_model_list()}
+        assert args.model_id in by_id, f"Unexpected model id {args.model_id}"
+
+        model = by_id[args.model_id]
+        if args.source == "huggingface":
+            self._hf_download(model, args.hf_token, args.ignore_patterns)
+        else:
+            if not args.meta_url:
+                self.parser.error(
+                    "Please provide a meta url to download the model from llama.meta.com"
+                )
+            self._meta_download(model, args.meta_url)
+
+
+class ResumableDownloader:
+    def __init__(self, url: str, output_file: str, buffer_size: int = 32 * 1024):
+        self.url = url
+        self.output_file = output_file
+        self.buffer_size = buffer_size
+        self.total_size: Optional[int] = None
+        self.downloaded_size = 0
+        self.start_size = 0
+        self.start_time = 0
+
+    async def get_file_info(self, client: httpx.AsyncClient) -> None:
+        response = await client.head(self.url, follow_redirects=True)
+        response.raise_for_status()
+        self.url = str(response.url)  # Update URL in case of redirects
+        self.total_size = int(response.headers.get("Content-Length", 0))
+        if self.total_size == 0:
+            raise ValueError(
+                "Unable to determine file size. The server might not support range requests."
+            )
+
+    async def download(self) -> None:
+        self.start_time = time.time()
+        async with httpx.AsyncClient() as client:
+            await self.get_file_info(client)
+
+            if os.path.exists(self.output_file):
+                self.downloaded_size = os.path.getsize(self.output_file)
+                self.start_size = self.downloaded_size
+                if self.downloaded_size >= self.total_size:
+                    print(f"Already downloaded `{self.output_file}`, skipping...")
+                    return
+
+            additional_size = self.total_size - self.downloaded_size
+            if not self.has_disk_space(additional_size):
+                M = 1024 * 1024  # noqa
+                print(
+                    f"Not enough disk space to download `{self.output_file}`. "
+                    f"Required: {(additional_size / M):.2f} MB"
+                )
+                raise ValueError(
+                    f"Not enough disk space to download `{self.output_file}`"
+                )
+
+            headers = {"Range": f"bytes={self.downloaded_size}-"}
+            try:
+                async with client.stream("GET", self.url, headers=headers) as response:
+                    response.raise_for_status()
+                    with open(self.output_file, "ab") as file:
+                        async for chunk in response.aiter_bytes(self.buffer_size):
+                            file.write(chunk)
+                            self.downloaded_size += len(chunk)
+                            self.print_progress()
+            except httpx.HTTPError as e:
+                print(f"\nDownload interrupted: {e}")
+                print("You can resume the download by running the script again.")
+            except Exception as e:
+                print(f"\nAn error occurred: {e}")
+            else:
+                print(f"\nFinished downloading `{self.output_file}`....")
+
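`download()` appends whatever the server returns for the `Range: bytes=N-` request. A server that ignores the `Range` header replies `200` with the full body, and appending that to an existing partial file would corrupt it. A hedged hardening sketch (not in the commit) that only appends on a `206 Partial Content` response:

    import httpx

    async def append_with_resume_check(url: str, output_file: str, offset: int) -> None:
        # Only resume when the server confirms partial content; otherwise start over.
        headers = {"Range": f"bytes={offset}-"} if offset else {}
        async with httpx.AsyncClient() as client:
            async with client.stream("GET", url, headers=headers) as response:
                response.raise_for_status()
                if offset and response.status_code != 206:
                    # Server sent the whole file; appending it would corrupt the partial download.
                    raise RuntimeError("Server ignored Range; delete the partial file and retry")
                mode = "ab" if offset else "wb"
                with open(output_file, mode) as f:
                    async for chunk in response.aiter_bytes(32 * 1024):
                        f.write(chunk)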
+    def print_progress(self) -> None:
+        percent = (self.downloaded_size / self.total_size) * 100
+        bar_length = 50
+        filled_length = int(bar_length * self.downloaded_size // self.total_size)
+        bar = "█" * filled_length + "-" * (bar_length - filled_length)
+
+        elapsed_time = time.time() - self.start_time
+        M = 1024 * 1024  # noqa
+
+        speed = (
+            (self.downloaded_size - self.start_size) / (elapsed_time * M)
+            if elapsed_time > 0
+            else 0
+        )
+        print(
+            f"\rProgress: |{bar}| {percent:.2f}% "
+            f"({self.downloaded_size // M}/{self.total_size // M} MB) "
+            f"Speed: {speed:.2f} MiB/s",
+            end="",
+            flush=True,
+        )
+
+    def has_disk_space(self, file_size: int) -> bool:
+        dir_path = os.path.dirname(os.path.abspath(self.output_file))
+        free_space = shutil.disk_usage(dir_path).free
+        return free_space > file_size
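For reference, a minimal standalone usage sketch of the new class (placeholder URL and output path; re-running the same call resumes from whatever partial file is already on disk):

    import asyncio

    # Assumes ResumableDownloader from the diff above is in scope.
    downloader = ResumableDownloader(
        url="https://example.invalid/consolidated.00.pth",  # placeholder
        output_file="/tmp/consolidated.00.pth",             # placeholder
    )
    asyncio.run(downloader.download())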