Support parallel downloads for llama model download (#448)

# What does this PR do?

Enables parallel downloads for the `llama model download` CLI command. This is especially useful for folks with high-bandwidth Internet connections who want to download checkpoints quickly.
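
A typical invocation with the new flag might look like this (`--max-parallel` defaults to 3; the model ID below is illustrative):

```bash
# Download up to 8 files concurrently (illustrative model ID)
llama model download --source meta --model-id Llama3.2-3B-Instruct --max-parallel 8
```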

## Test Plan


![image](https://github.com/user-attachments/assets/f5df69e2-ec4f-4360-bf84-91273d8cee22)
Commit 0713607b68 (parent 0c750102c6) by Ashwin Bharambe, 2024-11-14 09:56:22 -08:00, committed via GitHub.

@@ -9,15 +9,27 @@ import asyncio
import json
import os
import shutil
import time
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional
import httpx
from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel
from rich.console import Console
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
@@ -61,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
        required=False,
        help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
    )
    parser.add_argument(
        "--max-parallel",
        type=int,
        required=False,
        default=3,
        help="Maximum number of concurrent downloads",
    )
    parser.add_argument(
        "--ignore-patterns",
        type=str,
@@ -80,6 +99,245 @@ safetensors files to avoid downloading duplicate weights.
    parser.set_defaults(func=partial(run_download_cmd, parser=parser))

@dataclass
class DownloadTask:
    url: str
    output_file: str
    total_size: int = 0
    downloaded_size: int = 0
    task_id: Optional[int] = None
    retries: int = 0
    max_retries: int = 3


class DownloadError(Exception):
    pass
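
# Progress column that renders "-" instead of a stale transfer speed once a task finishes.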
class CustomTransferSpeedColumn(TransferSpeedColumn):
    def render(self, task):
        if task.finished:
            return "-"
        return super().render(task)

class ParallelDownloader:
    def __init__(
        self,
        max_concurrent_downloads: int = 3,
        buffer_size: int = 1024 * 1024,
        timeout: int = 30,
    ):
        self.max_concurrent_downloads = max_concurrent_downloads
        self.buffer_size = buffer_size
        self.timeout = timeout
        self.console = Console()
        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(bar_width=40),
            "[progress.percentage]{task.percentage:>3.1f}%",
            DownloadColumn(),
            CustomTransferSpeedColumn(),
            TimeRemainingColumn(),
            console=self.console,
            expand=True,
        )
        self.client_options = {
            "timeout": httpx.Timeout(timeout),
            "follow_redirects": True,
        }
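
    # Retry helper: waits 1s, 2s, 4s, ... between attempts (capped at 30s) before giving up.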
    async def retry_with_exponential_backoff(
        self, task: DownloadTask, func, *args, **kwargs
    ):
        last_exception = None
        for attempt in range(task.max_retries):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                if attempt < task.max_retries - 1:
                    wait_time = min(30, 2**attempt)  # Cap at 30 seconds
                    self.console.print(
                        f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
                        f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
                    )
                    await asyncio.sleep(wait_time)
                    continue
        raise last_exception

    async def get_file_info(
        self, client: httpx.AsyncClient, task: DownloadTask
    ) -> None:
        async def _get_info():
            response = await client.head(
                task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
            )
            response.raise_for_status()
            return response

        try:
            response = await self.retry_with_exponential_backoff(task, _get_info)

            task.url = str(response.url)
            task.total_size = int(response.headers.get("Content-Length", 0))
            if task.total_size == 0:
                raise DownloadError(
                    f"Unable to determine file size for {task.output_file}. "
                    "The server might not support range requests."
                )

            # Update the progress bar's total size once we know it
            if task.task_id is not None:
                self.progress.update(task.task_id, total=task.total_size)
        except httpx.HTTPError as e:
            self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
            raise

    def verify_file_integrity(self, task: DownloadTask) -> bool:
        if not os.path.exists(task.output_file):
            return False
        return os.path.getsize(task.output_file) == task.total_size
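
    # Fetch one byte range via an HTTP Range request and append it to the output file.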
    async def download_chunk(
        self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
    ) -> None:
        async def _download_chunk():
            headers = {"Range": f"bytes={start}-{end}"}
            async with client.stream(
                "GET", task.url, headers=headers, **self.client_options
            ) as response:
                response.raise_for_status()

                # Chunks are requested sequentially starting at the current file
                # size, so appending keeps bytes in order. (A seek() here would be
                # ignored anyway: writes always go to the end in append mode.)
                with open(task.output_file, "ab") as file:
                    async for chunk in response.aiter_bytes(self.buffer_size):
                        file.write(chunk)
                        task.downloaded_size += len(chunk)
                        self.progress.update(
                            task.task_id,
                            completed=task.downloaded_size,
                        )

        try:
            await self.retry_with_exponential_backoff(task, _download_chunk)
        except Exception as e:
            raise DownloadError(
                f"Failed to download chunk {start}-{end} after "
                f"{task.max_retries} attempts: {str(e)}"
            ) from e

    async def prepare_download(self, task: DownloadTask) -> None:
        output_dir = os.path.dirname(task.output_file)
        os.makedirs(output_dir, exist_ok=True)

        # Resume support: an existing partial file sets downloaded_size
        if os.path.exists(task.output_file):
            task.downloaded_size = os.path.getsize(task.output_file)
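
    # Download a single file: probe its size, skip it if already complete,
    # then fetch the remainder in ordered chunks.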
    async def download_file(self, task: DownloadTask) -> None:
        try:
            async with httpx.AsyncClient(**self.client_options) as client:
                await self.get_file_info(client, task)

                # Check if file is already downloaded
                if os.path.exists(task.output_file):
                    if self.verify_file_integrity(task):
                        self.console.print(
                            f"[green]Already downloaded {task.output_file}[/green]"
                        )
                        self.progress.update(task.task_id, completed=task.total_size)
                        return

                await self.prepare_download(task)

                try:
                    # Split the remaining download into chunks
                    chunk_size = 27_000_000_000  # Cloudfront max chunk size
                    chunks = []
                    current_pos = task.downloaded_size
                    while current_pos < task.total_size:
                        chunk_end = min(
                            current_pos + chunk_size - 1, task.total_size - 1
                        )
                        chunks.append((current_pos, chunk_end))
                        current_pos = chunk_end + 1

                    # Download chunks in sequence
                    for chunk_start, chunk_end in chunks:
                        await self.download_chunk(client, task, chunk_start, chunk_end)
                except Exception as e:
                    raise DownloadError(f"Download failed: {str(e)}") from e
        except Exception as e:
            self.progress.update(
                task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
            )
            raise DownloadError(
                f"Download failed for {task.output_file}: {str(e)}"
            ) from e
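
    # Check that the filesystem has room for all remaining bytes before starting.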
    def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
        try:
            total_remaining_size = sum(
                task.total_size - task.downloaded_size for task in tasks
            )
            dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
            free_space = shutil.disk_usage(dir_path).free

            # Add 10% buffer for safety
            required_space = int(total_remaining_size * 1.1)

            if free_space < required_space:
                self.console.print(
                    f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
                    f"Available: {free_space // (1024 * 1024)} MB[/red]"
                )
                return False
            return True
        except Exception as e:
            raise DownloadError(f"Failed to check disk space: {str(e)}") from e
    async def download_all(self, tasks: List[DownloadTask]) -> None:
        if not tasks:
            raise ValueError("No download tasks provided")

        if not self.has_disk_space(tasks):
            raise DownloadError("Insufficient disk space for downloads")

        failed_tasks = []

        with self.progress:
            for task in tasks:
                desc = f"Downloading {Path(task.output_file).name}"
                task.task_id = self.progress.add_task(
                    desc, total=task.total_size, completed=task.downloaded_size
                )
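
            # The semaphore caps concurrent downloads at max_concurrent_downloads;
            # the remaining tasks queue here until a slot frees up.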
            semaphore = asyncio.Semaphore(self.max_concurrent_downloads)

            async def download_with_semaphore(task: DownloadTask):
                async with semaphore:
                    try:
                        await self.download_file(task)
                    except Exception as e:
                        failed_tasks.append((task, str(e)))

            await asyncio.gather(*(download_with_semaphore(task) for task in tasks))

        if failed_tasks:
            self.console.print("\n[red]Some downloads failed:[/red]")
            for task, error in failed_tasks:
                self.console.print(
                    f"[red]- {Path(task.output_file).name}: {error}[/red]"
                )
            raise DownloadError(f"{len(failed_tasks)} downloads failed")

def _hf_download(
    model: "Model",
    hf_token: str,
@@ -120,63 +378,37 @@ def _hf_download(
    print(f"\nSuccessfully downloaded model to {true_output_dir}")

def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
def _meta_download(
    model: "Model",
    meta_url: str,
    info: "LlamaDownloadInfo",
    max_concurrent_downloads: int,
):
    from llama_stack.distribution.utils.model_utils import model_local_dir

    output_dir = Path(model_local_dir(model.descriptor()))
    os.makedirs(output_dir, exist_ok=True)

    # I believe we can use some concurrency here if needed but not sure it is worth it
    # Create download tasks for each file
    tasks = []
    for f in info.files:
        output_file = str(output_dir / f)
        url = meta_url.replace("*", f"{info.folder}/{f}")
        total_size = info.pth_size if "consolidated" in f else 0
        cprint(f"Downloading `{f}`...", "white")
        downloader = ResumableDownloader(url, output_file, total_size)
        asyncio.run(downloader.download())
        tasks.append(
            DownloadTask(
                url=url, output_file=output_file, total_size=total_size, max_retries=3
            )
        )

    # Initialize and run parallel downloader
    downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
    asyncio.run(downloader.download_all(tasks))

    print(f"\nSuccessfully downloaded model to {output_dir}")
    cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")

def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    from llama_models.sku_list import llama_meta_net_info, resolve_model

    from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku

    if args.manifest_file:
        _download_from_manifest(args.manifest_file)
        return

    if args.model_id is None:
        parser.error("Please provide a model id")
        return

    # Check if model_id is a comma-separated list
    model_ids = [model_id.strip() for model_id in args.model_id.split(",")]

    prompt_guard = prompt_guard_model_sku()
    for model_id in model_ids:
        if model_id == prompt_guard.model_id:
            model = prompt_guard
            info = prompt_guard_download_info()
        else:
            model = resolve_model(model_id)
            if model is None:
                parser.error(f"Model {model_id} not found")
                continue
            info = llama_meta_net_info(model)

        if args.source == "huggingface":
            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
        else:
            meta_url = args.meta_url or input(
                f"Please provide the signed URL for model {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
            )
            assert "llamameta.net" in meta_url
            _meta_download(model, meta_url, info)

class ModelEntry(BaseModel):
    model_id: str
    files: Dict[str, str]
@@ -190,7 +422,7 @@ class Manifest(BaseModel):
    expires_on: datetime


def _download_from_manifest(manifest_file: str):
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
    from llama_stack.distribution.utils.model_utils import model_local_dir

    with open(manifest_file, "r") as f:
@@ -200,143 +432,88 @@ def _download_from_manifest(manifest_file: str):
    if datetime.now() > manifest.expires_on:
        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

    console = Console()
    for entry in manifest.models:
        print(f"Downloading model {entry.model_id}...")
        console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
        output_dir = Path(model_local_dir(entry.model_id))
        os.makedirs(output_dir, exist_ok=True)

        if any(output_dir.iterdir()):
            cprint(f"Output directory {output_dir} is not empty.", "red")
            console.print(
                f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
            )

            while True:
                resp = input(
                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                )
                if resp.lower() == "restart" or resp.lower() == "r":
                if resp.lower() in ["restart", "r"]:
                    shutil.rmtree(output_dir)
                    os.makedirs(output_dir, exist_ok=True)
                    break
                elif resp.lower() == "continue" or resp.lower() == "c":
                    print("Continuing download...")
                elif resp.lower() in ["continue", "c"]:
                    console.print("[blue]Continuing download...[/blue]")
                    break
                else:
                    cprint("Invalid response. Please try again.", "red")
                    console.print("[red]Invalid response. Please try again.[/red]")

        for fname, url in entry.files.items():
            output_file = str(output_dir / fname)
            downloader = ResumableDownloader(url, output_file)
            asyncio.run(downloader.download())
        # Create download tasks for all files in the manifest
        tasks = [
            DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
            for fname, url in entry.files.items()
        ]

        # Initialize and run parallel downloader
        downloader = ParallelDownloader(
            max_concurrent_downloads=max_concurrent_downloads
        )
        asyncio.run(downloader.download_all(tasks))

class ResumableDownloader:
    def __init__(
        self,
        url: str,
        output_file: str,
        total_size: int = 0,
        buffer_size: int = 32 * 1024,
    ):
        self.url = url
        self.output_file = output_file
        self.buffer_size = buffer_size
        self.total_size = total_size
        self.downloaded_size = 0
        self.start_size = 0
        self.start_time = 0
    async def get_file_info(self, client: httpx.AsyncClient) -> None:
        if self.total_size > 0:
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Main download command handler"""
    try:
        if args.manifest_file:
            _download_from_manifest(args.manifest_file, args.max_parallel)
            return

        # Force disable compression when trying to retrieve file size
        response = await client.head(
            self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
        )
        response.raise_for_status()
        self.url = str(response.url)  # Update URL in case of redirects

        self.total_size = int(response.headers.get("Content-Length", 0))
        if self.total_size == 0:
            raise ValueError(
                "Unable to determine file size. The server might not support range requests."
            )

        if args.model_id is None:
            parser.error("Please provide a model id")
            return
    async def download(self) -> None:
        self.start_time = time.time()
        async with httpx.AsyncClient(follow_redirects=True) as client:
            await self.get_file_info(client)
        # Handle comma-separated model IDs
        model_ids = [model_id.strip() for model_id in args.model_id.split(",")]

            if os.path.exists(self.output_file):
                self.downloaded_size = os.path.getsize(self.output_file)
                self.start_size = self.downloaded_size
                if self.downloaded_size >= self.total_size:
                    print(f"Already downloaded `{self.output_file}`, skipping...")
                    return
        from llama_models.sku_list import llama_meta_net_info, resolve_model
            additional_size = self.total_size - self.downloaded_size
            if not self.has_disk_space(additional_size):
                M = 1024 * 1024  # noqa
                print(
                    f"Not enough disk space to download `{self.output_file}`. "
                    f"Required: {(additional_size // M):.2f} MB"
                )
                raise ValueError(
                    f"Not enough disk space to download `{self.output_file}`"
                )

            while True:
                if self.downloaded_size >= self.total_size:
                    break

                # Cloudfront has a max-size limit
                max_chunk_size = 27_000_000_000
                request_size = min(
                    self.total_size - self.downloaded_size, max_chunk_size
                )
                headers = {
                    "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
                }
                print(f"Downloading `{self.output_file}`....{headers}")
                try:
                    async with client.stream(
                        "GET", self.url, headers=headers
                    ) as response:
                        response.raise_for_status()
                        with open(self.output_file, "ab") as file:
                            async for chunk in response.aiter_bytes(self.buffer_size):
                                file.write(chunk)
                                self.downloaded_size += len(chunk)
                                self.print_progress()
                except httpx.HTTPError as e:
                    print(f"\nDownload interrupted: {e}")
                    print("You can resume the download by running the script again.")
                except Exception as e:
                    print(f"\nAn error occurred: {e}")

            print(f"\nFinished downloading `{self.output_file}`....")

    def print_progress(self) -> None:
        percent = (self.downloaded_size / self.total_size) * 100
        bar_length = 50
        filled_length = int(bar_length * self.downloaded_size // self.total_size)
        bar = "█" * filled_length + "-" * (bar_length - filled_length)

        elapsed_time = time.time() - self.start_time
        M = 1024 * 1024  # noqa

        speed = (
            (self.downloaded_size - self.start_size) / (elapsed_time * M)
            if elapsed_time > 0
            else 0
        )
        print(
            f"\rProgress: |{bar}| {percent:.2f}% "
            f"({self.downloaded_size // M}/{self.total_size // M} MB) "
            f"Speed: {speed:.2f} MiB/s",
            end="",
            flush=True,
        )
        from .model.safety_models import (
            prompt_guard_download_info,
            prompt_guard_model_sku,
        )

    def has_disk_space(self, file_size: int) -> bool:
        dir_path = os.path.dirname(os.path.abspath(self.output_file))
        free_space = shutil.disk_usage(dir_path).free
        return free_space > file_size
        prompt_guard = prompt_guard_model_sku()
        for model_id in model_ids:
            if model_id == prompt_guard.model_id:
                model = prompt_guard
                info = prompt_guard_download_info()
            else:
                model = resolve_model(model_id)
                if model is None:
                    parser.error(f"Model {model_id} not found")
                    continue
                info = llama_meta_net_info(model)

            if args.source == "huggingface":
                _hf_download(model, args.hf_token, args.ignore_patterns, parser)
            else:
                meta_url = args.meta_url or input(
                    f"Please provide the signed URL for model {model_id} you received via email "
                    f"after visiting https://www.llama.com/llama-downloads/ "
                    f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
                )
                if "llamameta.net" not in meta_url:
                    parser.error("Invalid Meta URL provided")
                _meta_download(model, meta_url, info, args.max_parallel)
    except Exception as e:
        parser.error(f"Download failed: {str(e)}")