# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import asyncio
import json
import os
import shutil
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List

import httpx
from pydantic import BaseModel
from termcolor import cprint

from llama_stack.cli.subcommand import Subcommand

if TYPE_CHECKING:
    from llama_models.datatypes import Model


class Download(Subcommand):
    """Llama CLI for downloading llama toolchain assets"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "download",
            prog="llama download",
            description="Download a model from llama.meta.com or Hugging Face Hub",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        setup_download_parser(self.parser)


def setup_download_parser(parser: argparse.ArgumentParser) -> None:
    from llama_models.sku_list import all_registered_models

    models = all_registered_models()
    parser.add_argument(
        "--source",
        choices=["meta", "huggingface"],
        default="meta",
    )
    parser.add_argument(
        "--model-id",
        required=False,
        help="See `llama model list` or `llama model list --show-all` for the list of available models",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        required=False,
        default=None,
        help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read the `HF_TOKEN` environment variable as a default.",
    )
    parser.add_argument(
        "--meta-url",
        type=str,
        required=False,
        help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
    )
    parser.add_argument(
        "--ignore-patterns",
        type=str,
        required=False,
        default="*.safetensors",
        help="""
For source=huggingface, files matching any of the patterns are not downloaded.
Defaults to ignoring safetensors files to avoid downloading duplicate weights.
""",
    )
    parser.add_argument(
        "--manifest-file",
        type=str,
        help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
        required=False,
    )
    parser.set_defaults(func=partial(run_download_cmd, parser=parser))


def _hf_download(
    model: "Model",
    hf_token: str,
    ignore_patterns: str,
    parser: argparse.ArgumentParser,
):
    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

    from llama_stack.distribution.utils.model_utils import model_local_dir

    repo_id = model.huggingface_repo
    if repo_id is None:
        raise ValueError(f"No repo id found for model {model.descriptor()}")

    output_dir = model_local_dir(model.descriptor())
    os.makedirs(output_dir, exist_ok=True)
    try:
        true_output_dir = snapshot_download(
            repo_id,
            local_dir=output_dir,
            ignore_patterns=ignore_patterns,
            token=hf_token,
            library_name="llama-stack",
        )
    except GatedRepoError:
        parser.error(
            "It looks like you are trying to access a gated repository. Please ensure you "
            "have access to the repository and have provided the proper Hugging Face API token "
            "using the option `--hf-token` or by running `huggingface-cli login`. "
            "You can find your token by visiting https://huggingface.co/settings/tokens"
        )
    except RepositoryNotFoundError:
        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
    except Exception as e:
        parser.error(str(e))

    print(f"\nSuccessfully downloaded model to {true_output_dir}")


def _meta_download(model: "Model", meta_url: str):
    from llama_models.sku_list import llama_meta_net_info

    from llama_stack.distribution.utils.model_utils import model_local_dir

    output_dir = Path(model_local_dir(model.descriptor()))
    os.makedirs(output_dir, exist_ok=True)

    info = llama_meta_net_info(model)

    # Files are downloaded sequentially; concurrency could speed this up but is
    # likely not worth the added complexity.
    for f in info.files:
        output_file = str(output_dir / f)
        url = meta_url.replace("*", f"{info.folder}/{f}")
        total_size = info.pth_size if "consolidated" in f else 0
        cprint(f"Downloading `{f}`...", "white")
        downloader = ResumableDownloader(url, output_file, total_size)
        asyncio.run(downloader.download())

    print(f"\nSuccessfully downloaded model to {output_dir}")
    cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")


def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    from llama_models.sku_list import resolve_model

    if args.manifest_file:
        _download_from_manifest(args.manifest_file)
        return

    if args.model_id is None:
        parser.error("Please provide a model id")
        return

    model = resolve_model(args.model_id)
    if model is None:
        parser.error(f"Model {args.model_id} not found")
        return

    if args.source == "huggingface":
        _hf_download(model, args.hf_token, args.ignore_patterns, parser)
    else:
        meta_url = args.meta_url
        if not meta_url:
            meta_url = input(
                "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
            )
            assert meta_url is not None and "llamameta.net" in meta_url
        _meta_download(model, meta_url)


class ModelEntry(BaseModel):
    model_id: str
    files: Dict[str, str]

    class Config:
        # Allow field names starting with "model_" without pydantic warnings
        protected_namespaces = ()


class Manifest(BaseModel):
    models: List[ModelEntry]
    expires_on: datetime


def _download_from_manifest(manifest_file: str):
    from llama_stack.distribution.utils.model_utils import model_local_dir

    with open(manifest_file, "r") as f:
        d = json.load(f)
        manifest = Manifest(**d)

    if datetime.now() > manifest.expires_on:
        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

    for entry in manifest.models:
        print(f"Downloading model {entry.model_id}...")
        output_dir = Path(model_local_dir(entry.model_id))
        os.makedirs(output_dir, exist_ok=True)

        if any(output_dir.iterdir()):
            cprint(f"Output directory {output_dir} is not empty.", "red")

            while True:
                resp = input(
                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                )
                if resp.lower() in ("restart", "r"):
                    shutil.rmtree(output_dir)
                    os.makedirs(output_dir, exist_ok=True)
                    break
                elif resp.lower() in ("continue", "c"):
                    print("Continuing download...")
                    break
                else:
                    cprint("Invalid response. Please try again.", "red")

        for fname, url in entry.files.items():
            output_file = str(output_dir / fname)
            downloader = ResumableDownloader(url, output_file)
            asyncio.run(downloader.download())


class ResumableDownloader:
    """Streams a file to disk over HTTP, resuming partial downloads via Range requests."""

    def __init__(
        self,
        url: str,
        output_file: str,
        total_size: int = 0,
        buffer_size: int = 32 * 1024,
    ):
        self.url = url
        self.output_file = output_file
        self.buffer_size = buffer_size
        self.total_size = total_size
        self.downloaded_size = 0
        self.start_size = 0
        self.start_time = 0

    async def get_file_info(self, client: httpx.AsyncClient) -> None:
        if self.total_size > 0:
            return

        # Force disable compression when trying to retrieve file size
        response = await client.head(
            self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
        )
        response.raise_for_status()
        self.url = str(response.url)  # Update URL in case of redirects
        self.total_size = int(response.headers.get("Content-Length", 0))
        if self.total_size == 0:
            raise ValueError(
                "Unable to determine file size. The server might not support range requests."
            )

    async def download(self) -> None:
        self.start_time = time.time()
        async with httpx.AsyncClient(follow_redirects=True) as client:
            await self.get_file_info(client)

            # Resume from an existing partial file, if any
            if os.path.exists(self.output_file):
                self.downloaded_size = os.path.getsize(self.output_file)
                self.start_size = self.downloaded_size
                if self.downloaded_size >= self.total_size:
                    print(f"Already downloaded `{self.output_file}`, skipping...")
                    return

            additional_size = self.total_size - self.downloaded_size
            if not self.has_disk_space(additional_size):
                M = 1024 * 1024  # noqa
                print(
                    f"Not enough disk space to download `{self.output_file}`. "
                    f"Required: {(additional_size / M):.2f} MiB"
                )
                raise ValueError(
                    f"Not enough disk space to download `{self.output_file}`"
                )

            while True:
                if self.downloaded_size >= self.total_size:
                    break

                # Cloudfront has a max-size limit
                max_chunk_size = 27_000_000_000
                request_size = min(
                    self.total_size - self.downloaded_size, max_chunk_size
                )
                # The end of an HTTP Range is inclusive, hence the -1
                headers = {
                    "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size - 1}"
                }
                print(f"Downloading `{self.output_file}`....{headers}")
                try:
                    async with client.stream(
                        "GET", self.url, headers=headers
                    ) as response:
                        response.raise_for_status()
                        with open(self.output_file, "ab") as file:
                            async for chunk in response.aiter_bytes(self.buffer_size):
                                file.write(chunk)
                                self.downloaded_size += len(chunk)
                                self.print_progress()
                except httpx.HTTPError as e:
                    print(f"\nDownload interrupted: {e}")
                    print("You can resume the download by running the script again.")
                    return
                except Exception as e:
                    print(f"\nAn error occurred: {e}")
                    return

            print(f"\nFinished downloading `{self.output_file}`....")

    def print_progress(self) -> None:
        percent = (self.downloaded_size / self.total_size) * 100
        bar_length = 50
        filled_length = int(bar_length * self.downloaded_size // self.total_size)
        bar = "█" * filled_length + "-" * (bar_length - filled_length)

        elapsed_time = time.time() - self.start_time
        M = 1024 * 1024  # noqa
        speed = (
            (self.downloaded_size - self.start_size) / (elapsed_time * M)
            if elapsed_time > 0
            else 0
        )

        print(
            f"\rProgress: |{bar}| {percent:.2f}% "
            f"({self.downloaded_size // M}/{self.total_size // M} MiB) "
            f"Speed: {speed:.2f} MiB/s",
            end="",
            flush=True,
        )

    def has_disk_space(self, file_size: int) -> bool:
        dir_path = os.path.dirname(os.path.abspath(self.output_file))
        free_space = shutil.disk_usage(dir_path).free
        return free_space > file_size
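

# Hypothetical usage sketch, not part of the `llama download` CLI flow: shows how
# ResumableDownloader could be exercised directly, assuming the target server
# supports HTTP Range requests and reports a Content-Length. The URL and output
# path are placeholders supplied on the command line, not llama-stack defaults.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 3:
        print("usage: python download.py <url> <output_file>")
        sys.exit(1)

    asyncio.run(ResumableDownloader(sys.argv[1], sys.argv[2]).download())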