diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py
index 4a0f88aaa..07b40bd21 100644
--- a/llama_stack/cli/download.py
+++ b/llama_stack/cli/download.py
@@ -9,15 +9,27 @@
 import asyncio
 import json
 import os
 import shutil
-import time
+from dataclasses import dataclass
 from datetime import datetime
 from functools import partial
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import httpx
+
+from llama_models.datatypes import Model
+from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    DownloadColumn,
+    Progress,
+    TextColumn,
+    TimeRemainingColumn,
+    TransferSpeedColumn,
+)
 from termcolor import cprint
 
 from llama_stack.cli.subcommand import Subcommand
@@ -61,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
         required=False,
         help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
     )
+    parser.add_argument(
+        "--max-parallel",
+        type=int,
+        required=False,
+        default=3,
+        help="Maximum number of concurrent downloads",
+    )
     parser.add_argument(
         "--ignore-patterns",
         type=str,
@@ -80,6 +99,245 @@ safetensors files to avoid downloading duplicate weights.
 
     parser.set_defaults(func=partial(run_download_cmd, parser=parser))
 
 
+@dataclass
+class DownloadTask:
+    url: str
+    output_file: str
+    total_size: int = 0
+    downloaded_size: int = 0
+    task_id: Optional[int] = None
+    retries: int = 0
+    max_retries: int = 3
+
+
+class DownloadError(Exception):
+    pass
+
+
+class CustomTransferSpeedColumn(TransferSpeedColumn):
+    # Render "-" instead of a stale transfer rate once a task finishes
+    def render(self, task):
+        if task.finished:
+            return "-"
+        return super().render(task)
+
+
+class ParallelDownloader:
+    def __init__(
+        self,
+        max_concurrent_downloads: int = 3,
+        buffer_size: int = 1024 * 1024,
+        timeout: int = 30,
+    ):
+        self.max_concurrent_downloads = max_concurrent_downloads
+        self.buffer_size = buffer_size
+        self.timeout = timeout
+        self.console = Console()
+        self.progress = Progress(
+            TextColumn("[bold blue]{task.description}"),
+            BarColumn(bar_width=40),
+            "[progress.percentage]{task.percentage:>3.1f}%",
+            DownloadColumn(),
+            CustomTransferSpeedColumn(),
+            TimeRemainingColumn(),
+            console=self.console,
+            expand=True,
+        )
+        self.client_options = {
+            "timeout": httpx.Timeout(timeout),
+            "follow_redirects": True,
+        }
+
+    async def retry_with_exponential_backoff(
+        self, task: DownloadTask, func, *args, **kwargs
+    ):
+        last_exception = None
+        for attempt in range(task.max_retries):
+            try:
+                return await func(*args, **kwargs)
+            except Exception as e:
+                last_exception = e
+                if attempt < task.max_retries - 1:
+                    wait_time = min(30, 2**attempt)  # Cap at 30 seconds
+                    self.console.print(
+                        f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
+                        f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
+                    )
+                    await asyncio.sleep(wait_time)
+                    continue
+                raise last_exception
+
+    async def get_file_info(
+        self, client: httpx.AsyncClient, task: DownloadTask
+    ) -> None:
+        async def _get_info():
+            response = await client.head(
+                task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
+            )
+            response.raise_for_status()
+            return response
+
+        try:
+            response = await self.retry_with_exponential_backoff(task, _get_info)
+
+            task.url = str(response.url)
+            task.total_size = int(response.headers.get("Content-Length", 0))
+
+            if task.total_size == 0:
+                raise DownloadError(
+                    f"Unable to determine file size for {task.output_file}. "
+                    "The server might not support range requests."
+                )
+
+            # Update the progress bar's total size once we know it
+            if task.task_id is not None:
+                self.progress.update(task.task_id, total=task.total_size)
+
+        except httpx.HTTPError as e:
+            self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
+            raise
+
+    def verify_file_integrity(self, task: DownloadTask) -> bool:
+        if not os.path.exists(task.output_file):
+            return False
+        return os.path.getsize(task.output_file) == task.total_size
+
+    async def download_chunk(
+        self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
+    ) -> None:
+        async def _download_chunk():
+            headers = {"Range": f"bytes={start}-{end}"}
+            async with client.stream(
+                "GET", task.url, headers=headers, **self.client_options
+            ) as response:
+                response.raise_for_status()
+
+                # Open in r+b so writes land at the requested offset; append
+                # mode would silently ignore seek() and, on a retry, tack
+                # duplicate bytes onto the end of the file.
+                with open(task.output_file, "r+b") as file:
+                    file.seek(start)
+                    async for chunk in response.aiter_bytes(self.buffer_size):
+                        file.write(chunk)
+                        task.downloaded_size += len(chunk)
+                        self.progress.update(
+                            task.task_id,
+                            completed=task.downloaded_size,
+                        )
+
+        try:
+            await self.retry_with_exponential_backoff(task, _download_chunk)
+        except Exception as e:
+            raise DownloadError(
+                f"Failed to download chunk {start}-{end} after "
+                f"{task.max_retries} attempts: {str(e)}"
+            ) from e
+
+    async def prepare_download(self, task: DownloadTask) -> None:
+        output_dir = os.path.dirname(task.output_file)
+        os.makedirs(output_dir, exist_ok=True)
+
+        if os.path.exists(task.output_file):
+            task.downloaded_size = os.path.getsize(task.output_file)
+        else:
+            # Pre-create the file so download_chunk can open it in r+b mode
+            open(task.output_file, "wb").close()
+
+    async def download_file(self, task: DownloadTask) -> None:
+        try:
+            async with httpx.AsyncClient(**self.client_options) as client:
+                await self.get_file_info(client, task)
+
+                # Check if file is already downloaded
+                if os.path.exists(task.output_file):
+                    if self.verify_file_integrity(task):
+                        self.console.print(
+                            f"[green]Already downloaded {task.output_file}[/green]"
+                        )
+                        self.progress.update(task.task_id, completed=task.total_size)
+                        return
+
+                await self.prepare_download(task)
+
+                try:
+                    # Split the remaining download into chunks
+                    chunk_size = 27_000_000_000  # Cloudfront max chunk size
+                    chunks = []
+
+                    current_pos = task.downloaded_size
+                    while current_pos < task.total_size:
+                        chunk_end = min(
+                            current_pos + chunk_size - 1, task.total_size - 1
+                        )
+                        chunks.append((current_pos, chunk_end))
+                        current_pos = chunk_end + 1
+
+                    # Download chunks in sequence
+                    for chunk_start, chunk_end in chunks:
+                        await self.download_chunk(client, task, chunk_start, chunk_end)
+
+                except Exception as e:
+                    raise DownloadError(f"Download failed: {str(e)}") from e
+
+        except Exception as e:
+            self.progress.update(
+                task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
+            )
+            raise DownloadError(
+                f"Download failed for {task.output_file}: {str(e)}"
+            ) from e
+
+    def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
+        try:
+            total_remaining_size = sum(
+                task.total_size - task.downloaded_size for task in tasks
+            )
+            dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
+            free_space = shutil.disk_usage(dir_path).free
+
+            # Add 10% buffer for safety
+            required_space = int(total_remaining_size * 1.1)
+
+            if free_space < required_space:
+                self.console.print(
+                    f"[red]Not enough disk space. Required: {required_space // (1024*1024)} MB, "
+                    f"Available: {free_space // (1024*1024)} MB[/red]"
+                )
+                return False
+            return True
+
+        except Exception as e:
+            raise DownloadError(f"Failed to check disk space: {str(e)}") from e
+
+    async def download_all(self, tasks: List[DownloadTask]) -> None:
+        if not tasks:
+            raise ValueError("No download tasks provided")
+
+        if not self.has_disk_space(tasks):
+            raise DownloadError("Insufficient disk space for downloads")
+
+        failed_tasks = []
+
+        with self.progress:
+            for task in tasks:
+                desc = f"Downloading {Path(task.output_file).name}"
+                task.task_id = self.progress.add_task(
+                    desc, total=task.total_size, completed=task.downloaded_size
+                )
+
+            semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
+
+            async def download_with_semaphore(task: DownloadTask):
+                async with semaphore:
+                    try:
+                        await self.download_file(task)
+                    except Exception as e:
+                        failed_tasks.append((task, str(e)))
+
+            await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
+
+        if failed_tasks:
+            self.console.print("\n[red]Some downloads failed:[/red]")
+            for task, error in failed_tasks:
+                self.console.print(
+                    f"[red]- {Path(task.output_file).name}: {error}[/red]"
+                )
+            raise DownloadError(f"{len(failed_tasks)} downloads failed")
+
+
 def _hf_download(
     model: "Model",
     hf_token: str,
@@ -120,63 +378,37 @@ def _hf_download(
     print(f"\nSuccessfully downloaded model to {true_output_dir}")
 
 
-def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
+def _meta_download(
+    model: "Model",
+    meta_url: str,
+    info: "LlamaDownloadInfo",
+    max_concurrent_downloads: int,
+):
     from llama_stack.distribution.utils.model_utils import model_local_dir
 
     output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)
 
-    # I believe we can use some concurrency here if needed but not sure it is worth it
+    # Create download tasks for each file
+    tasks = []
     for f in info.files:
         output_file = str(output_dir / f)
         url = meta_url.replace("*", f"{info.folder}/{f}")
         total_size = info.pth_size if "consolidated" in f else 0
-        cprint(f"Downloading `{f}`...", "white")
-        downloader = ResumableDownloader(url, output_file, total_size)
-        asyncio.run(downloader.download())
+        tasks.append(
+            DownloadTask(
+                url=url, output_file=output_file, total_size=total_size, max_retries=3
+            )
+        )
+
+    # Initialize and run parallel downloader
+    downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
+    asyncio.run(downloader.download_all(tasks))
 
     print(f"\nSuccessfully downloaded model to {output_dir}")
     cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
 
 
-def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
-    from llama_models.sku_list import llama_meta_net_info, resolve_model
-
-    from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
-
-    if args.manifest_file:
-        _download_from_manifest(args.manifest_file)
-        return
-
-    if args.model_id is None:
-        parser.error("Please provide a model id")
-        return
-
-    # Check if model_id is a comma-separated list
-    model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
-
-    prompt_guard = prompt_guard_model_sku()
-    for model_id in model_ids:
-        if model_id == prompt_guard.model_id:
-            model = prompt_guard
-            info = prompt_guard_download_info()
-        else:
-            model = resolve_model(model_id)
-            if model is None:
-                parser.error(f"Model {model_id} not found")
-                continue
-            info = llama_meta_net_info(model)
-
-        if args.source == "huggingface":
-            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
-        else:
-            meta_url = args.meta_url or input(
-                f"Please provide the signed URL for model {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
-            )
-            assert "llamameta.net" in meta_url
-            _meta_download(model, meta_url, info)
-
-
 class ModelEntry(BaseModel):
     model_id: str
     files: Dict[str, str]
@@ -190,7 +422,7 @@ class Manifest(BaseModel):
     expires_on: datetime
 
 
-def _download_from_manifest(manifest_file: str):
+def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
    from llama_stack.distribution.utils.model_utils import model_local_dir
 
     with open(manifest_file, "r") as f:
@@ -200,143 +432,88 @@
     if datetime.now() > manifest.expires_on:
         raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
 
+    console = Console()
     for entry in manifest.models:
-        print(f"Downloading model {entry.model_id}...")
+        console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
         output_dir = Path(model_local_dir(entry.model_id))
         os.makedirs(output_dir, exist_ok=True)
 
         if any(output_dir.iterdir()):
-            cprint(f"Output directory {output_dir} is not empty.", "red")
+            console.print(
+                f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
+            )
 
             while True:
                 resp = input(
                     "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                 )
-                if resp.lower() == "restart" or resp.lower() == "r":
+                if resp.lower() in ["restart", "r"]:
                     shutil.rmtree(output_dir)
                     os.makedirs(output_dir, exist_ok=True)
                     break
-                elif resp.lower() == "continue" or resp.lower() == "c":
-                    print("Continuing download...")
+                elif resp.lower() in ["continue", "c"]:
+                    console.print("[blue]Continuing download...[/blue]")
                     break
                 else:
-                    cprint("Invalid response. Please try again.", "red")
+                    console.print("[red]Invalid response. Please try again.[/red]")
 
-        for fname, url in entry.files.items():
-            output_file = str(output_dir / fname)
-            downloader = ResumableDownloader(url, output_file)
-            asyncio.run(downloader.download())
+        # Create download tasks for all files in the manifest
+        tasks = [
+            DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
+            for fname, url in entry.files.items()
+        ]
+
+        # Initialize and run parallel downloader
+        downloader = ParallelDownloader(
+            max_concurrent_downloads=max_concurrent_downloads
+        )
+        asyncio.run(downloader.download_all(tasks))
 
 
-class ResumableDownloader:
-    def __init__(
-        self,
-        url: str,
-        output_file: str,
-        total_size: int = 0,
-        buffer_size: int = 32 * 1024,
-    ):
-        self.url = url
-        self.output_file = output_file
-        self.buffer_size = buffer_size
-        self.total_size = total_size
-        self.downloaded_size = 0
-        self.start_size = 0
-        self.start_time = 0
-
-    async def get_file_info(self, client: httpx.AsyncClient) -> None:
-        if self.total_size > 0:
+def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
+    """Main download command handler"""
+    try:
+        if args.manifest_file:
+            _download_from_manifest(args.manifest_file, args.max_parallel)
             return
 
-        # Force disable compression when trying to retrieve file size
-        response = await client.head(
-            self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
-        )
-        response.raise_for_status()
-        self.url = str(response.url)  # Update URL in case of redirects
-        self.total_size = int(response.headers.get("Content-Length", 0))
-        if self.total_size == 0:
-            raise ValueError(
-                "Unable to determine file size. The server might not support range requests."
-            )
+        if args.model_id is None:
+            parser.error("Please provide a model id")
+            return
 
-    async def download(self) -> None:
-        self.start_time = time.time()
-        async with httpx.AsyncClient(follow_redirects=True) as client:
-            await self.get_file_info(client)
+        # Handle comma-separated model IDs
+        model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
 
-            if os.path.exists(self.output_file):
-                self.downloaded_size = os.path.getsize(self.output_file)
-                self.start_size = self.downloaded_size
-                if self.downloaded_size >= self.total_size:
-                    print(f"Already downloaded `{self.output_file}`, skipping...")
-                    return
+        from llama_models.sku_list import llama_meta_net_info, resolve_model
 
-            additional_size = self.total_size - self.downloaded_size
-            if not self.has_disk_space(additional_size):
-                M = 1024 * 1024  # noqa
-                print(
-                    f"Not enough disk space to download `{self.output_file}`. "
-                    f"Required: {(additional_size // M):.2f} MB"
-                )
-                raise ValueError(
-                    f"Not enough disk space to download `{self.output_file}`"
-                )
-
-            while True:
-                if self.downloaded_size >= self.total_size:
-                    break
-
-                # Cloudfront has a max-size limit
-                max_chunk_size = 27_000_000_000
-                request_size = min(
-                    self.total_size - self.downloaded_size, max_chunk_size
-                )
-                headers = {
-                    "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
-                }
-                print(f"Downloading `{self.output_file}`....{headers}")
-                try:
-                    async with client.stream(
-                        "GET", self.url, headers=headers
-                    ) as response:
-                        response.raise_for_status()
-                        with open(self.output_file, "ab") as file:
-                            async for chunk in response.aiter_bytes(self.buffer_size):
-                                file.write(chunk)
-                                self.downloaded_size += len(chunk)
-                                self.print_progress()
-                except httpx.HTTPError as e:
-                    print(f"\nDownload interrupted: {e}")
-                    print("You can resume the download by running the script again.")
-                except Exception as e:
-                    print(f"\nAn error occurred: {e}")
-
-        print(f"\nFinished downloading `{self.output_file}`....")
-
-    def print_progress(self) -> None:
-        percent = (self.downloaded_size / self.total_size) * 100
-        bar_length = 50
-        filled_length = int(bar_length * self.downloaded_size // self.total_size)
-        bar = "█" * filled_length + "-" * (bar_length - filled_length)
-
-        elapsed_time = time.time() - self.start_time
-        M = 1024 * 1024  # noqa
-
-        speed = (
-            (self.downloaded_size - self.start_size) / (elapsed_time * M)
-            if elapsed_time > 0
-            else 0
-        )
-        print(
-            f"\rProgress: |{bar}| {percent:.2f}% "
-            f"({self.downloaded_size // M}/{self.total_size // M} MB) "
-            f"Speed: {speed:.2f} MiB/s",
-            end="",
-            flush=True,
+        from .model.safety_models import (
+            prompt_guard_download_info,
+            prompt_guard_model_sku,
         )
 
-    def has_disk_space(self, file_size: int) -> bool:
-        dir_path = os.path.dirname(os.path.abspath(self.output_file))
-        free_space = shutil.disk_usage(dir_path).free
-        return free_space > file_size
+        prompt_guard = prompt_guard_model_sku()
+        for model_id in model_ids:
+            if model_id == prompt_guard.model_id:
+                model = prompt_guard
+                info = prompt_guard_download_info()
+            else:
+                model = resolve_model(model_id)
+                if model is None:
+                    parser.error(f"Model {model_id} not found")
+                    continue
+                info = llama_meta_net_info(model)
+
+            if args.source == "huggingface":
+                _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+            else:
+                meta_url = args.meta_url or input(
+                    f"Please provide the signed URL for model {model_id} you received via email "
+                    f"after visiting https://www.llama.com/llama-downloads/ "
+                    f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
+                )
+                if "llamameta.net" not in meta_url:
+                    parser.error("Invalid Meta URL provided")
+                _meta_download(model, meta_url, info, args.max_parallel)
+
+    except Exception as e:
+        parser.error(f"Download failed: {str(e)}")