Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-06-28 19:04:19 +00:00
Support parallel downloads for llama model download (#448)
# What does this PR do?

Enables parallel downloads for the `llama model download` CLI command. This matters for folks with high-bandwidth Internet connections who want to pull checkpoints quickly.

## Test Plan
This commit is contained in:
parent 0c750102c6
commit 0713607b68

1 changed file with 338 additions and 161 deletions
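For reference, a hypothetical invocation exercising the new flag (the `--max-parallel` name and its default of 3 come from the diff below; the model id is illustrative):

```bash
llama model download --source meta --model-id Llama3.1-8B-Instruct --max-parallel 8
```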
```diff
@@ -9,15 +9,27 @@ import asyncio
 import json
 import os
 import shutil
-import time
+from dataclasses import dataclass
 from datetime import datetime
 from functools import partial
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import httpx
 
+from llama_models.datatypes import Model
+from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel
+
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    DownloadColumn,
+    Progress,
+    TextColumn,
+    TimeRemainingColumn,
+    TransferSpeedColumn,
+)
 from termcolor import cprint
 
 from llama_stack.cli.subcommand import Subcommand
```
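The new Rich imports replace the hand-rolled progress printer deleted at the bottom of this diff. As orientation, a minimal self-contained sketch (not from the PR) of how these columns compose into a live progress display:

```python
import time

from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)

# Same column stack that ParallelDownloader assembles below.
progress = Progress(
    TextColumn("[bold blue]{task.description}"),
    BarColumn(bar_width=40),
    "[progress.percentage]{task.percentage:>3.1f}%",
    DownloadColumn(),
    TransferSpeedColumn(),
    TimeRemainingColumn(),
)

with progress:
    task_id = progress.add_task("Downloading demo.bin", total=1_000_000)
    while not progress.finished:
        progress.update(task_id, advance=100_000)  # pretend 100 kB arrived
        time.sleep(0.1)
```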
```diff
@@ -61,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
         required=False,
         help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
     )
+    parser.add_argument(
+        "--max-parallel",
+        type=int,
+        required=False,
+        default=3,
+        help="Maximum number of concurrent downloads",
+    )
     parser.add_argument(
         "--ignore-patterns",
         type=str,
```
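The new `--max-parallel` value ends up as `ParallelDownloader(max_concurrent_downloads=...)`, which gates concurrency with an `asyncio.Semaphore` in `download_all` further down. A minimal standalone sketch of that pattern, with a sleep standing in for network I/O:

```python
import asyncio


async def fake_download(name: str, semaphore: asyncio.Semaphore) -> None:
    # Only max_parallel coroutines can hold the semaphore at once;
    # the rest queue here until a slot frees up.
    async with semaphore:
        print(f"start {name}")
        await asyncio.sleep(1)  # stand-in for network I/O
        print(f"done  {name}")


async def main(max_parallel: int = 3) -> None:
    semaphore = asyncio.Semaphore(max_parallel)
    files = [f"file{i}.pth" for i in range(8)]
    await asyncio.gather(*(fake_download(f, semaphore) for f in files))


asyncio.run(main())
```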
```diff
@@ -80,6 +99,245 @@ safetensors files to avoid downloading duplicate weights.
     parser.set_defaults(func=partial(run_download_cmd, parser=parser))
 
 
+@dataclass
+class DownloadTask:
+    url: str
+    output_file: str
+    total_size: int = 0
+    downloaded_size: int = 0
+    task_id: Optional[int] = None
+    retries: int = 0
+    max_retries: int = 3
+
+
+class DownloadError(Exception):
+    pass
+
+
+class CustomTransferSpeedColumn(TransferSpeedColumn):
+    def render(self, task):
+        if task.finished:
+            return "-"
+        return super().render(task)
+
+
+class ParallelDownloader:
+    def __init__(
+        self,
+        max_concurrent_downloads: int = 3,
+        buffer_size: int = 1024 * 1024,
+        timeout: int = 30,
+    ):
+        self.max_concurrent_downloads = max_concurrent_downloads
+        self.buffer_size = buffer_size
+        self.timeout = timeout
+        self.console = Console()
+        self.progress = Progress(
+            TextColumn("[bold blue]{task.description}"),
+            BarColumn(bar_width=40),
+            "[progress.percentage]{task.percentage:>3.1f}%",
+            DownloadColumn(),
+            CustomTransferSpeedColumn(),
+            TimeRemainingColumn(),
+            console=self.console,
+            expand=True,
+        )
+        self.client_options = {
+            "timeout": httpx.Timeout(timeout),
+            "follow_redirects": True,
+        }
+
+    async def retry_with_exponential_backoff(
+        self, task: DownloadTask, func, *args, **kwargs
+    ):
+        last_exception = None
+        for attempt in range(task.max_retries):
+            try:
+                return await func(*args, **kwargs)
+            except Exception as e:
+                last_exception = e
+                if attempt < task.max_retries - 1:
+                    wait_time = min(30, 2**attempt)  # Cap at 30 seconds
+                    self.console.print(
+                        f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
+                        f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
+                    )
+                    await asyncio.sleep(wait_time)
+                    continue
+        raise last_exception
+
+    async def get_file_info(
+        self, client: httpx.AsyncClient, task: DownloadTask
+    ) -> None:
+        async def _get_info():
+            response = await client.head(
+                task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
+            )
+            response.raise_for_status()
+            return response
+
+        try:
+            response = await self.retry_with_exponential_backoff(task, _get_info)
+
+            task.url = str(response.url)
+            task.total_size = int(response.headers.get("Content-Length", 0))
+
+            if task.total_size == 0:
+                raise DownloadError(
+                    f"Unable to determine file size for {task.output_file}. "
+                    "The server might not support range requests."
+                )
+
+            # Update the progress bar's total size once we know it
+            if task.task_id is not None:
+                self.progress.update(task.task_id, total=task.total_size)
+
+        except httpx.HTTPError as e:
+            self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
+            raise
+
+    def verify_file_integrity(self, task: DownloadTask) -> bool:
+        if not os.path.exists(task.output_file):
+            return False
+        return os.path.getsize(task.output_file) == task.total_size
+
+    async def download_chunk(
+        self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
+    ) -> None:
+        async def _download_chunk():
+            headers = {"Range": f"bytes={start}-{end}"}
+            async with client.stream(
+                "GET", task.url, headers=headers, **self.client_options
+            ) as response:
+                response.raise_for_status()
+
+                with open(task.output_file, "ab") as file:
+                    file.seek(start)
+                    async for chunk in response.aiter_bytes(self.buffer_size):
+                        file.write(chunk)
+                        task.downloaded_size += len(chunk)
+                        self.progress.update(
+                            task.task_id,
+                            completed=task.downloaded_size,
+                        )
+
+        try:
+            await self.retry_with_exponential_backoff(task, _download_chunk)
+        except Exception as e:
+            raise DownloadError(
+                f"Failed to download chunk {start}-{end} after "
+                f"{task.max_retries} attempts: {str(e)}"
+            ) from e
+
+    async def prepare_download(self, task: DownloadTask) -> None:
+        output_dir = os.path.dirname(task.output_file)
+        os.makedirs(output_dir, exist_ok=True)
+
+        if os.path.exists(task.output_file):
+            task.downloaded_size = os.path.getsize(task.output_file)
+
+    async def download_file(self, task: DownloadTask) -> None:
+        try:
+            async with httpx.AsyncClient(**self.client_options) as client:
+                await self.get_file_info(client, task)
+
+                # Check if file is already downloaded
+                if os.path.exists(task.output_file):
+                    if self.verify_file_integrity(task):
+                        self.console.print(
+                            f"[green]Already downloaded {task.output_file}[/green]"
+                        )
+                        self.progress.update(task.task_id, completed=task.total_size)
+                        return
+
+                await self.prepare_download(task)
+
+                try:
+                    # Split the remaining download into chunks
+                    chunk_size = 27_000_000_000  # Cloudfront max chunk size
+                    chunks = []
+
+                    current_pos = task.downloaded_size
+                    while current_pos < task.total_size:
+                        chunk_end = min(
+                            current_pos + chunk_size - 1, task.total_size - 1
+                        )
+                        chunks.append((current_pos, chunk_end))
+                        current_pos = chunk_end + 1
+
+                    # Download chunks in sequence
+                    for chunk_start, chunk_end in chunks:
+                        await self.download_chunk(client, task, chunk_start, chunk_end)
+
+                except Exception as e:
+                    raise DownloadError(f"Download failed: {str(e)}") from e
+
+        except Exception as e:
+            self.progress.update(
+                task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
+            )
+            raise DownloadError(
+                f"Download failed for {task.output_file}: {str(e)}"
+            ) from e
+
+    def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
+        try:
+            total_remaining_size = sum(
+                task.total_size - task.downloaded_size for task in tasks
+            )
+            dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
+            free_space = shutil.disk_usage(dir_path).free
+
+            # Add 10% buffer for safety
+            required_space = int(total_remaining_size * 1.1)
+
+            if free_space < required_space:
+                self.console.print(
+                    f"[red]Not enough disk space. Required: {required_space // (1024*1024)} MB, "
+                    f"Available: {free_space // (1024*1024)} MB[/red]"
+                )
+                return False
+            return True
+
+        except Exception as e:
+            raise DownloadError(f"Failed to check disk space: {str(e)}") from e
+
+    async def download_all(self, tasks: List[DownloadTask]) -> None:
+        if not tasks:
+            raise ValueError("No download tasks provided")
+
+        if not self.has_disk_space(tasks):
+            raise DownloadError("Insufficient disk space for downloads")
+
+        failed_tasks = []
+
+        with self.progress:
+            for task in tasks:
+                desc = f"Downloading {Path(task.output_file).name}"
+                task.task_id = self.progress.add_task(
+                    desc, total=task.total_size, completed=task.downloaded_size
+                )
+
+            semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
+
+            async def download_with_semaphore(task: DownloadTask):
+                async with semaphore:
+                    try:
+                        await self.download_file(task)
+                    except Exception as e:
+                        failed_tasks.append((task, str(e)))
+
+            await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
+
+        if failed_tasks:
+            self.console.print("\n[red]Some downloads failed:[/red]")
+            for task, error in failed_tasks:
+                self.console.print(
+                    f"[red]- {Path(task.output_file).name}: {error}[/red]"
+                )
+            raise DownloadError(f"{len(failed_tasks)} downloads failed")
+
+
 def _hf_download(
     model: "Model",
     hf_token: str,
```
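A worked example of the chunk-splitting arithmetic in `download_file` above, shrunk to toy numbers so the inclusive byte ranges are visible (the real `chunk_size` is 27 GB, which the code attributes to a Cloudfront per-request cap):

```python
# Same loop as download_file, with chunk_size=100 instead of 27 GB.
chunk_size = 100
total_size = 250
downloaded_size = 0  # resume point; nonzero if a partial file already exists

chunks = []
current_pos = downloaded_size
while current_pos < total_size:
    chunk_end = min(current_pos + chunk_size - 1, total_size - 1)
    chunks.append((current_pos, chunk_end))
    current_pos = chunk_end + 1

print(chunks)  # [(0, 99), (100, 199), (200, 249)] -> "Range: bytes=start-end"
```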
```diff
@@ -120,63 +378,37 @@ def _hf_download(
     print(f"\nSuccessfully downloaded model to {true_output_dir}")
 
 
-def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
+def _meta_download(
+    model: "Model",
+    meta_url: str,
+    info: "LlamaDownloadInfo",
+    max_concurrent_downloads: int,
+):
     from llama_stack.distribution.utils.model_utils import model_local_dir
 
     output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)
 
-    # I believe we can use some concurrency here if needed but not sure it is worth it
+    # Create download tasks for each file
+    tasks = []
     for f in info.files:
         output_file = str(output_dir / f)
         url = meta_url.replace("*", f"{info.folder}/{f}")
         total_size = info.pth_size if "consolidated" in f else 0
-        cprint(f"Downloading `{f}`...", "white")
-        downloader = ResumableDownloader(url, output_file, total_size)
-        asyncio.run(downloader.download())
+        tasks.append(
+            DownloadTask(
+                url=url, output_file=output_file, total_size=total_size, max_retries=3
+            )
+        )
+
+    # Initialize and run parallel downloader
+    downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
+    asyncio.run(downloader.download_all(tasks))
 
     print(f"\nSuccessfully downloaded model to {output_dir}")
     cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
 
 
-def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
-    from llama_models.sku_list import llama_meta_net_info, resolve_model
-
-    from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
-
-    if args.manifest_file:
-        _download_from_manifest(args.manifest_file)
-        return
-
-    if args.model_id is None:
-        parser.error("Please provide a model id")
-        return
-
-    # Check if model_id is a comma-separated list
-    model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
-
-    prompt_guard = prompt_guard_model_sku()
-    for model_id in model_ids:
-        if model_id == prompt_guard.model_id:
-            model = prompt_guard
-            info = prompt_guard_download_info()
-        else:
-            model = resolve_model(model_id)
-            if model is None:
-                parser.error(f"Model {model_id} not found")
-                continue
-            info = llama_meta_net_info(model)
-
-        if args.source == "huggingface":
-            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
-        else:
-            meta_url = args.meta_url or input(
-                f"Please provide the signed URL for model {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
-            )
-            assert "llamameta.net" in meta_url
-            _meta_download(model, meta_url, info)
-
-
 class ModelEntry(BaseModel):
     model_id: str
     files: Dict[str, str]
```
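`_meta_download` expands the signed wildcard URL once per file via `meta_url.replace("*", ...)`. A small illustration with placeholder values (the folder name and query signature are made up):

```python
# The "*" in the signed URL is replaced with "<folder>/<filename>".
meta_url = "https://llama3-1.llamameta.net/*?Policy=abc"  # placeholder signature
folder = "Meta-Llama-3.1-8B"  # placeholder folder from LlamaDownloadInfo
for f in ["consolidated.00.pth", "tokenizer.model", "checklist.chk"]:
    print(meta_url.replace("*", f"{folder}/{f}"))
# -> https://llama3-1.llamameta.net/Meta-Llama-3.1-8B/consolidated.00.pth?Policy=abc
```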
```diff
@@ -190,7 +422,7 @@ class Manifest(BaseModel):
     expires_on: datetime
 
 
-def _download_from_manifest(manifest_file: str):
+def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
     from llama_stack.distribution.utils.model_utils import model_local_dir
 
     with open(manifest_file, "r") as f:
```
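For context, a hedged sketch of parsing a manifest against the `ModelEntry`/`Manifest` models above; the JSON shape is inferred from their fields and from the expiry check in the next hunk:

```python
from datetime import datetime, timedelta

# Hypothetical manifest payload matching the ModelEntry/Manifest fields.
manifest_json = {
    "models": [
        {
            "model_id": "Llama3.1-8B",  # illustrative
            "files": {"tokenizer.model": "https://example.com/tokenizer.model"},
        }
    ],
    "expires_on": (datetime.now() + timedelta(hours=24)).isoformat(),
}

manifest = Manifest(**manifest_json)  # pydantic coerces the nested types
if datetime.now() > manifest.expires_on:
    raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
```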
```diff
@@ -200,143 +432,88 @@ def _download_from_manifest(manifest_file: str):
     if datetime.now() > manifest.expires_on:
         raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
 
+    console = Console()
     for entry in manifest.models:
-        print(f"Downloading model {entry.model_id}...")
+        console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
         output_dir = Path(model_local_dir(entry.model_id))
         os.makedirs(output_dir, exist_ok=True)
 
         if any(output_dir.iterdir()):
-            cprint(f"Output directory {output_dir} is not empty.", "red")
+            console.print(
+                f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
+            )
 
             while True:
                 resp = input(
                     "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                 )
-                if resp.lower() == "restart" or resp.lower() == "r":
+                if resp.lower() in ["restart", "r"]:
                     shutil.rmtree(output_dir)
                     os.makedirs(output_dir, exist_ok=True)
                     break
-                elif resp.lower() == "continue" or resp.lower() == "c":
-                    print("Continuing download...")
+                elif resp.lower() in ["continue", "c"]:
+                    console.print("[blue]Continuing download...[/blue]")
                     break
                 else:
-                    cprint("Invalid response. Please try again.", "red")
+                    console.print("[red]Invalid response. Please try again.[/red]")
 
-        for fname, url in entry.files.items():
-            output_file = str(output_dir / fname)
-            downloader = ResumableDownloader(url, output_file)
-            asyncio.run(downloader.download())
+        # Create download tasks for all files in the manifest
+        tasks = [
+            DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
+            for fname, url in entry.files.items()
+        ]
+
+        # Initialize and run parallel downloader
+        downloader = ParallelDownloader(
+            max_concurrent_downloads=max_concurrent_downloads
+        )
+        asyncio.run(downloader.download_all(tasks))
 
 
-class ResumableDownloader:
-    def __init__(
-        self,
-        url: str,
-        output_file: str,
-        total_size: int = 0,
-        buffer_size: int = 32 * 1024,
-    ):
-        self.url = url
-        self.output_file = output_file
-        self.buffer_size = buffer_size
-        self.total_size = total_size
-        self.downloaded_size = 0
-        self.start_size = 0
-        self.start_time = 0
-
-    async def get_file_info(self, client: httpx.AsyncClient) -> None:
-        if self.total_size > 0:
-            return
-
-        # Force disable compression when trying to retrieve file size
-        response = await client.head(
-            self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
-        )
-        response.raise_for_status()
-        self.url = str(response.url)  # Update URL in case of redirects
-        self.total_size = int(response.headers.get("Content-Length", 0))
-        if self.total_size == 0:
-            raise ValueError(
-                "Unable to determine file size. The server might not support range requests."
-            )
-
-    async def download(self) -> None:
-        self.start_time = time.time()
-        async with httpx.AsyncClient(follow_redirects=True) as client:
-            await self.get_file_info(client)
-
-            if os.path.exists(self.output_file):
-                self.downloaded_size = os.path.getsize(self.output_file)
-                self.start_size = self.downloaded_size
-                if self.downloaded_size >= self.total_size:
-                    print(f"Already downloaded `{self.output_file}`, skipping...")
-                    return
-
-            additional_size = self.total_size - self.downloaded_size
-            if not self.has_disk_space(additional_size):
-                M = 1024 * 1024  # noqa
-                print(
-                    f"Not enough disk space to download `{self.output_file}`. "
-                    f"Required: {(additional_size // M):.2f} MB"
-                )
-                raise ValueError(
-                    f"Not enough disk space to download `{self.output_file}`"
-                )
-
-            while True:
-                if self.downloaded_size >= self.total_size:
-                    break
-
-                # Cloudfront has a max-size limit
-                max_chunk_size = 27_000_000_000
-                request_size = min(
-                    self.total_size - self.downloaded_size, max_chunk_size
-                )
-                headers = {
-                    "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
-                }
-                print(f"Downloading `{self.output_file}`....{headers}")
-                try:
-                    async with client.stream(
-                        "GET", self.url, headers=headers
-                    ) as response:
-                        response.raise_for_status()
-                        with open(self.output_file, "ab") as file:
-                            async for chunk in response.aiter_bytes(self.buffer_size):
-                                file.write(chunk)
-                                self.downloaded_size += len(chunk)
-                                self.print_progress()
-                except httpx.HTTPError as e:
-                    print(f"\nDownload interrupted: {e}")
-                    print("You can resume the download by running the script again.")
-                except Exception as e:
-                    print(f"\nAn error occurred: {e}")
-
-            print(f"\nFinished downloading `{self.output_file}`....")
-
-    def print_progress(self) -> None:
-        percent = (self.downloaded_size / self.total_size) * 100
-        bar_length = 50
-        filled_length = int(bar_length * self.downloaded_size // self.total_size)
-        bar = "█" * filled_length + "-" * (bar_length - filled_length)
-
-        elapsed_time = time.time() - self.start_time
-        M = 1024 * 1024  # noqa
-
-        speed = (
-            (self.downloaded_size - self.start_size) / (elapsed_time * M)
-            if elapsed_time > 0
-            else 0
-        )
-        print(
-            f"\rProgress: |{bar}| {percent:.2f}% "
-            f"({self.downloaded_size // M}/{self.total_size // M} MB) "
-            f"Speed: {speed:.2f} MiB/s",
-            end="",
-            flush=True,
-        )
-
-    def has_disk_space(self, file_size: int) -> bool:
-        dir_path = os.path.dirname(os.path.abspath(self.output_file))
-        free_space = shutil.disk_usage(dir_path).free
-        return free_space > file_size
+def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
+    """Main download command handler"""
+    try:
+        if args.manifest_file:
+            _download_from_manifest(args.manifest_file, args.max_parallel)
+            return
+
+        if args.model_id is None:
+            parser.error("Please provide a model id")
+            return
+
+        # Handle comma-separated model IDs
+        model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
+
+        from llama_models.sku_list import llama_meta_net_info, resolve_model
+
+        from .model.safety_models import (
+            prompt_guard_download_info,
+            prompt_guard_model_sku,
+        )
+
+        prompt_guard = prompt_guard_model_sku()
+        for model_id in model_ids:
+            if model_id == prompt_guard.model_id:
+                model = prompt_guard
+                info = prompt_guard_download_info()
+            else:
+                model = resolve_model(model_id)
+                if model is None:
+                    parser.error(f"Model {model_id} not found")
+                    continue
+                info = llama_meta_net_info(model)
+
+            if args.source == "huggingface":
+                _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+            else:
+                meta_url = args.meta_url or input(
+                    f"Please provide the signed URL for model {model_id} you received via email "
+                    f"after visiting https://www.llama.com/llama-downloads/ "
+                    f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
+                )
+                if "llamameta.net" not in meta_url:
+                    parser.error("Invalid Meta URL provided")
+                _meta_download(model, meta_url, info, args.max_parallel)
+
+    except Exception as e:
+        parser.error(f"Download failed: {str(e)}")
```
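Putting it together, a hedged end-to-end sketch of driving the new classes directly, mirroring what `_meta_download` and `_download_from_manifest` now do (URLs and paths are placeholders):

```python
import asyncio

# Assumes DownloadTask and ParallelDownloader from this diff are in scope.
tasks = [
    DownloadTask(
        url="https://example.com/llama/consolidated.00.pth",  # placeholder
        output_file="/tmp/llama/consolidated.00.pth",
        max_retries=3,
    ),
    DownloadTask(
        url="https://example.com/llama/tokenizer.model",  # placeholder
        output_file="/tmp/llama/tokenizer.model",
        max_retries=3,
    ),
]

downloader = ParallelDownloader(max_concurrent_downloads=3)
asyncio.run(downloader.download_all(tasks))
```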