forked from phoenix-oss/llama-stack-mirror
		
	# What does this PR do? Adds a message at the end of a successful download explaining how to optionally run the MD5 checksum verification command. ## Test Plan <img width="2004" alt="Screenshot 2024-11-19 at 12 11 37 PM" src="https://github.com/user-attachments/assets/8d617aef-99f5-4c3b-b93c-eff3e68289ea"> ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [x] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. --------- Co-authored-by: varunfb <vontimitta@devgpu004.eag5.facebook.com>
		
			
				
	
	
		
			526 lines
		
	
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			526 lines
		
	
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import argparse
 | |
| import asyncio
 | |
| import json
 | |
| import os
 | |
| import shutil
 | |
| from dataclasses import dataclass
 | |
| from datetime import datetime
 | |
| from functools import partial
 | |
| from pathlib import Path
 | |
| from typing import Dict, List, Optional
 | |
| 
 | |
| import httpx
 | |
| 
 | |
| from llama_models.datatypes import Model
 | |
| from llama_models.sku_list import LlamaDownloadInfo
 | |
| from pydantic import BaseModel, ConfigDict
 | |
| 
 | |
| from rich.console import Console
 | |
| from rich.progress import (
 | |
|     BarColumn,
 | |
|     DownloadColumn,
 | |
|     Progress,
 | |
|     TextColumn,
 | |
|     TimeRemainingColumn,
 | |
|     TransferSpeedColumn,
 | |
| )
 | |
| from termcolor import cprint
 | |
| 
 | |
| from llama_stack.cli.subcommand import Subcommand
 | |
| 
 | |
| 
 | |
class Download(Subcommand):
    """The `llama download` subcommand.

    Creates the ``download`` sub-parser; all of its options are attached by
    ``setup_download_parser``.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        parser_kwargs = dict(
            prog="llama download",
            description="Download a model from llama.meta.com or Hugging Face Hub",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self.parser = subparsers.add_parser("download", **parser_kwargs)
        setup_download_parser(self.parser)
 | |
| 
 | |
| 
 | |
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
    """Register every `llama download` option on *parser*.

    The command handler is stored as the ``func`` default: ``run_download_cmd``
    with *parser* pre-bound, so it can report problems via ``parser.error``.
    (Presumably the top-level CLI calls ``args.func(args)`` after parsing.)
    """

    def _positive_int(value: str) -> int:
        # --max-parallel ends up sizing an asyncio.Semaphore in
        # ParallelDownloader.download_all: 0 would block every download
        # forever and a negative value raises deep inside asyncio, so reject
        # both here with a regular argparse usage error.
        try:
            number = int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid int value: {value!r}"
            ) from None
        if number < 1:
            raise argparse.ArgumentTypeError(
                f"must be a positive integer, got {number}"
            )
        return number

    parser.add_argument(
        "--source",
        choices=["meta", "huggingface"],
        default="meta",
        help="Where to download the model from (default: %(default)s)",
    )
    parser.add_argument(
        "--model-id",
        required=False,
        help="See `llama model list` or `llama model list --show-all` for the list of available models",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        required=False,
        default=None,
        help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
    )
    parser.add_argument(
        "--meta-url",
        type=str,
        required=False,
        help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
    )
    parser.add_argument(
        "--max-parallel",
        type=_positive_int,
        required=False,
        default=3,
        help="Maximum number of concurrent downloads",
    )
    parser.add_argument(
        "--ignore-patterns",
        type=str,
        required=False,
        default="*.safetensors",
        help="""
For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
safetensors files to avoid downloading duplicate weights.
""",
    )
    parser.add_argument(
        "--manifest-file",
        type=str,
        help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
        required=False,
    )
    parser.set_defaults(func=partial(run_download_cmd, parser=parser))
 | |
| 
 | |
| 
 | |
@dataclass
class DownloadTask:
    """Mutable record describing one remote file and its local download state.

    ParallelDownloader reads and updates these fields in place while the
    transfer runs.
    """

    # Remote location; ParallelDownloader.get_file_info replaces it with the
    # final URL after redirects.
    url: str
    # Local destination path.
    output_file: str
    # Expected size in bytes; 0 means "not known yet".
    total_size: int = 0
    # Bytes already on disk / received so far.
    downloaded_size: int = 0
    # Row id in the rich progress display; None until download_all creates it.
    task_id: Optional[int] = None
    # NOTE(review): not read anywhere in this file; kept for compatibility.
    retries: int = 0
    # Number of attempts per network operation before giving up.
    max_retries: int = 3
 | |
| 
 | |
| 
 | |
class DownloadError(Exception):
    """Raised when a download cannot be prepared or completed."""
 | |
| 
 | |
| 
 | |
class CustomTransferSpeedColumn(TransferSpeedColumn):
    """Transfer-speed column that displays ``-`` once a task has finished."""

    def render(self, task):
        # Completed tasks show a dash instead of the parent's speed readout.
        return "-" if task.finished else super().render(task)
 | |
| 
 | |
| 
 | |
class ParallelDownloader:
    """Download several files concurrently with resume, retries and a progress UI.

    Each file is described by a ``DownloadTask``. A file already on disk with
    the expected size is skipped; a partial file is resumed with HTTP ``Range``
    requests that append to what is already there.
    """

    def __init__(
        self,
        max_concurrent_downloads: int = 3,
        buffer_size: int = 1024 * 1024,
        timeout: int = 30,
    ):
        """Configure the downloader.

        Args:
            max_concurrent_downloads: maximum number of files transferred at
                once. Must be >= 1: it sizes an ``asyncio.Semaphore`` in
                ``download_all``, and a semaphore of 0 would make every
                download wait forever.
            buffer_size: chunk size, in bytes, passed to ``aiter_bytes`` when
                streaming a response.
            timeout: httpx timeout, in seconds, applied to every request.

        Raises:
            ValueError: if ``max_concurrent_downloads`` is less than 1.
        """
        if max_concurrent_downloads < 1:
            raise ValueError(
                "max_concurrent_downloads must be at least 1, "
                f"got {max_concurrent_downloads}"
            )
        self.max_concurrent_downloads = max_concurrent_downloads
        self.buffer_size = buffer_size
        self.timeout = timeout
        self.console = Console()
        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(bar_width=40),
            "[progress.percentage]{task.percentage:>3.1f}%",
            DownloadColumn(),
            CustomTransferSpeedColumn(),
            TimeRemainingColumn(),
            console=self.console,
            expand=True,
        )
        self.client_options = {
            "timeout": httpx.Timeout(timeout),
            "follow_redirects": True,
        }

    def _update_progress(self, task: DownloadTask, **fields) -> None:
        """Apply ``fields`` to the task's progress row; no-op if it has no row yet."""
        if task.task_id is not None:
            self.progress.update(task.task_id, **fields)

    async def retry_with_exponential_backoff(
        self, task: DownloadTask, func, *args, **kwargs
    ):
        """Await ``func(*args, **kwargs)``, retrying failures with exponential backoff.

        Makes up to ``task.max_retries`` attempts (always at least one),
        sleeping 1s, 2s, 4s, ... (capped at 30s) between them. Returns the
        first successful result; re-raises the last exception if all fail.
        """
        # Guarantee one attempt: with max_retries <= 0 the old loop never ran
        # and ended in `raise None` (a TypeError).
        attempts = max(1, task.max_retries)
        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                if attempt == attempts - 1:
                    raise
                wait_time = min(30, 2**attempt)  # Cap at 30 seconds
                self.console.print(
                    f"[yellow]Attempt {attempt + 1}/{attempts} failed, "
                    f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
                )
                await asyncio.sleep(wait_time)

    async def get_file_info(
        self, client: httpx.AsyncClient, task: DownloadTask
    ) -> None:
        """Resolve redirects and learn the file size with a HEAD request.

        Sets ``task.url`` to the final (post-redirect) URL and
        ``task.total_size`` from the ``Content-Length`` header.

        Raises:
            DownloadError: if the server reports no (or a zero) Content-Length.
            httpx.HTTPError: if the HEAD request keeps failing.
        """

        async def _get_info():
            response = await client.head(
                task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
            )
            response.raise_for_status()
            return response

        try:
            response = await self.retry_with_exponential_backoff(task, _get_info)

            task.url = str(response.url)
            task.total_size = int(response.headers.get("Content-Length", 0))

            if task.total_size == 0:
                raise DownloadError(
                    f"Unable to determine file size for {task.output_file}. "
                    "The server might not support range requests."
                )

            # Update the progress bar's total size once we know it
            self._update_progress(task, total=task.total_size)

        except httpx.HTTPError as e:
            self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
            raise

    def verify_file_integrity(self, task: DownloadTask) -> bool:
        """Return True if the output file exists and has exactly the expected size."""
        if not os.path.exists(task.output_file):
            return False
        return os.path.getsize(task.output_file) == task.total_size

    async def download_chunk(
        self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
    ) -> None:
        """Fetch bytes ``start``..``end`` (inclusive) of ``task.url`` into the output file.

        Bytes are appended, so the file is expected to already hold everything
        before ``start`` (``download_file`` requests ranges in order). Every
        attempt re-reads the file size and requests only what is still
        missing. A failed attempt may already have appended part of the
        range, and re-requesting from ``start`` would append those bytes a
        second time and corrupt the file.

        Raises:
            DownloadError: if the range still cannot be fetched after all retries.
        """

        async def _download_chunk():
            offset = (
                os.path.getsize(task.output_file)
                if os.path.exists(task.output_file)
                else 0
            )
            if offset > end:
                return  # the whole range is already on disk
            task.downloaded_size = offset
            self._update_progress(task, completed=offset)

            headers = {
                "Range": f"bytes={offset}-{end}",
                # Same as the HEAD request: ask for raw bytes so sizes and
                # offsets refer to Content-Length, not a compressed body.
                "Accept-Encoding": "identity",
            }
            async with client.stream(
                "GET", task.url, headers=headers, **self.client_options
            ) as response:
                response.raise_for_status()
                if offset > 0 and response.status_code != 206:
                    # The server ignored the Range header and is resending the
                    # whole file; appending it would corrupt the partial copy.
                    raise DownloadError(
                        f"Server ignored range request for {task.output_file} "
                        f"(HTTP {response.status_code})"
                    )

                with open(task.output_file, "ab") as file:
                    async for chunk in response.aiter_bytes(self.buffer_size):
                        file.write(chunk)
                        task.downloaded_size += len(chunk)
                        self._update_progress(task, completed=task.downloaded_size)

        try:
            await self.retry_with_exponential_backoff(task, _download_chunk)
        except Exception as e:
            raise DownloadError(
                f"Failed to download chunk {start}-{end} after "
                f"{task.max_retries} attempts: {str(e)}"
            ) from e

    async def prepare_download(self, task: DownloadTask) -> None:
        """Ensure the output directory exists and sync ``downloaded_size`` with disk.

        A local file larger than ``total_size`` cannot be a prefix of the
        remote file, so it is deleted and the download restarts from byte 0.
        """
        output_dir = os.path.dirname(task.output_file)
        if output_dir:  # bare file names live in the CWD; makedirs("") would fail
            os.makedirs(output_dir, exist_ok=True)

        task.downloaded_size = 0
        if os.path.exists(task.output_file):
            existing_size = os.path.getsize(task.output_file)
            if task.total_size and existing_size > task.total_size:
                os.remove(task.output_file)
            else:
                task.downloaded_size = existing_size

    async def download_file(self, task: DownloadTask) -> None:
        """Download one task end to end.

        Skips the transfer when a complete copy already exists. Otherwise it
        resumes from whatever is on disk, fetches the remainder in sequential
        ranges, and finally checks the file size.

        Raises:
            DownloadError: wrapping any failure; the task's progress row is
                relabelled as failed first.
        """
        try:
            async with httpx.AsyncClient(**self.client_options) as client:
                await self.get_file_info(client, task)

                # Check if file is already downloaded
                if self.verify_file_integrity(task):
                    self.console.print(
                        f"[green]Already downloaded {task.output_file}[/green]"
                    )
                    self._update_progress(task, completed=task.total_size)
                    return

                await self.prepare_download(task)

                # Split the remaining download into sequential ranges.
                chunk_size = 27_000_000_000  # Cloudfront max chunk size
                for chunk_start in range(
                    task.downloaded_size, task.total_size, chunk_size
                ):
                    chunk_end = min(chunk_start + chunk_size, task.total_size) - 1
                    await self.download_chunk(client, task, chunk_start, chunk_end)

                # Catch truncated or over-long results before reporting success.
                if not self.verify_file_integrity(task):
                    actual_size = (
                        os.path.getsize(task.output_file)
                        if os.path.exists(task.output_file)
                        else 0
                    )
                    raise DownloadError(
                        f"Size mismatch for {task.output_file}: expected "
                        f"{task.total_size} bytes, got {actual_size}"
                    )

        except Exception as e:
            self._update_progress(
                task, description=f"[red]Failed: {task.output_file}[/red]"
            )
            raise DownloadError(
                f"Download failed for {task.output_file}: {str(e)}"
            ) from e

    def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
        """Check that the output filesystem can hold the remaining bytes plus 10%.

        Only sizes known at call time are counted. ``download_all`` calls this
        before any HEAD request, so a task whose ``total_size`` was not
        pre-filled still counts as 0 bytes. Free space is measured on the
        filesystem holding the first task's output directory.

        Raises:
            DownloadError: if the free space cannot be determined.
        """
        if not tasks:
            return True
        try:
            # Clamp at 0 so an oversized local file cannot offset the space
            # needed by the other tasks.
            total_remaining_size = sum(
                max(0, task.total_size - task.downloaded_size) for task in tasks
            )
            dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
            free_space = shutil.disk_usage(dir_path).free

            # Add 10% buffer for safety
            required_space = int(total_remaining_size * 1.1)

            if free_space < required_space:
                self.console.print(
                    f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
                    f"Available: {free_space // (1024 * 1024)} MB[/red]"
                )
                return False
            return True

        except Exception as e:
            raise DownloadError(f"Failed to check disk space: {str(e)}") from e

    async def download_all(self, tasks: List[DownloadTask]) -> None:
        """Download every task concurrently, at most ``max_concurrent_downloads`` at once.

        All tasks are attempted; failures are collected and reported together.

        Raises:
            ValueError: if ``tasks`` is empty.
            DownloadError: if disk space looks insufficient, or if any task failed.
        """
        if not tasks:
            raise ValueError("No download tasks provided")

        if not self.has_disk_space(tasks):
            raise DownloadError("Insufficient disk space for downloads")

        failed_tasks = []

        with self.progress:
            for task in tasks:
                desc = f"Downloading {Path(task.output_file).name}"
                task.task_id = self.progress.add_task(
                    desc, total=task.total_size, completed=task.downloaded_size
                )

            semaphore = asyncio.Semaphore(self.max_concurrent_downloads)

            async def download_with_semaphore(task: DownloadTask):
                async with semaphore:
                    try:
                        await self.download_file(task)
                    except Exception as e:
                        failed_tasks.append((task, str(e)))

            await asyncio.gather(*(download_with_semaphore(task) for task in tasks))

        if failed_tasks:
            self.console.print("\n[red]Some downloads failed:[/red]")
            for task, error in failed_tasks:
                self.console.print(
                    f"[red]- {Path(task.output_file).name}: {error}[/red]"
                )
            raise DownloadError(f"{len(failed_tasks)} downloads failed")
 | |
| 
 | |
| 
 | |
def _hf_download(
    model: "Model",
    hf_token: str,
    ignore_patterns: str,
    parser: argparse.ArgumentParser,
):
    """Download a model snapshot from the Hugging Face Hub into its local model dir.

    Args:
        model: model to fetch; its ``huggingface_repo`` names the Hub repo.
        hf_token: Hugging Face API token (may be None); passed to
            ``snapshot_download``.
        ignore_patterns: file pattern(s) to skip; passed to ``snapshot_download``.
        parser: used to report Hub failures via ``parser.error``.

    Raises:
        ValueError: if the model has no Hugging Face repo id.
    """
    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

    from llama_stack.distribution.utils.model_utils import model_local_dir

    repo_id = model.huggingface_repo
    if repo_id is None:
        raise ValueError(f"No repo id found for model {model.descriptor()}")

    output_dir = model_local_dir(model.descriptor())
    os.makedirs(output_dir, exist_ok=True)
    try:
        true_output_dir = snapshot_download(
            repo_id,
            local_dir=output_dir,
            ignore_patterns=ignore_patterns,
            token=hf_token,
            library_name="llama-stack",
        )
    except GatedRepoError:
        parser.error(
            "It looks like you are trying to access a gated repository. Please ensure you "
            "have access to the repository and have provided the proper Hugging Face API token "
            "using the option `--hf-token` or by running `huggingface-cli login`. "
            "You can find your token by visiting https://huggingface.co/settings/tokens"
        )
    except RepositoryNotFoundError:
        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
    except Exception as e:
        # parser.error expects a message string, not an exception object.
        parser.error(str(e))

    print(f"\nSuccessfully downloaded model to {true_output_dir}")
 | |
| 
 | |
| 
 | |
def _meta_download(
    model: "Model",
    model_id: str,
    meta_url: str,
    info: "LlamaDownloadInfo",
    max_concurrent_downloads: int,
):
    """Fetch every file of a model from Meta's signed URL into its local directory.

    The ``*`` in ``meta_url`` is replaced by ``<info.folder>/<file>`` for each
    entry of ``info.files``. The files are downloaded concurrently, and the
    user is then told where the checksum list lives and how to verify it.
    """
    from llama_stack.distribution.utils.model_utils import model_local_dir

    output_dir = Path(model_local_dir(model.descriptor()))
    os.makedirs(output_dir, exist_ok=True)

    def _task_for(filename: str) -> DownloadTask:
        # Only files named "consolidated*" get a size up front (info.pth_size);
        # the rest start at 0 and are sized by the downloader later.
        expected_size = info.pth_size if "consolidated" in filename else 0
        return DownloadTask(
            url=meta_url.replace("*", f"{info.folder}/{filename}"),
            output_file=str(output_dir / filename),
            total_size=expected_size,
            max_retries=3,
        )

    tasks = [_task_for(filename) for filename in info.files]

    downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
    asyncio.run(downloader.download_all(tasks))

    checklist_path = output_dir / "checklist.chk"
    cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
    cprint(f"\nView MD5 checksum files at: {checklist_path}", "white")
    cprint(
        f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
        "yellow",
    )
 | |
| 
 | |
| 
 | |
class ModelEntry(BaseModel):
    """One model listed in a download manifest."""

    # Used to pick the model's local output directory.
    model_id: str
    # File name -> download URL.
    files: Dict[str, str]

    # Disable pydantic's protected "model_" namespace so the `model_id`
    # field is accepted without a warning.
    model_config = ConfigDict(protected_namespaces=())
 | |
| 
 | |
| 
 | |
class Manifest(BaseModel):
    """Parsed `--manifest-file`: the models to download and when their URLs expire."""

    # Models to download, in order.
    models: List[ModelEntry]
    # After this moment the manifest is refused as expired.
    expires_on: datetime
 | |
| 
 | |
| 
 | |
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
    """Download every model listed in a JSON manifest file.

    Each manifest entry maps file names to URLs; the files are written to the
    model's local directory. If that directory already has content, the user
    is asked whether to continue (resume) or restart from an empty directory.

    Raises:
        ValueError: if the manifest's ``expires_on`` moment has passed.
        DownloadError: if any file of a model fails to download.
    """
    from datetime import timezone

    from llama_stack.distribution.utils.model_utils import model_local_dir

    with open(manifest_file, "r", encoding="utf-8") as f:
        manifest = Manifest(**json.load(f))

    # A timestamp with a UTC offset (e.g. "...Z") parses to an aware datetime,
    # and comparing aware with naive raises TypeError, so pick a "now" of the
    # same kind. Naive timestamps keep the original local-time comparison.
    expires_on = manifest.expires_on
    now = (
        datetime.now(timezone.utc)
        if expires_on.utcoffset() is not None
        else datetime.now()
    )
    if now > expires_on:
        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

    console = Console()
    for entry in manifest.models:
        console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
        output_dir = Path(model_local_dir(entry.model_id))
        os.makedirs(output_dir, exist_ok=True)

        if any(output_dir.iterdir()):
            console.print(
                f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
            )

            while True:
                resp = input(
                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                ).strip().lower()
                if resp in ("restart", "r"):
                    shutil.rmtree(output_dir)
                    os.makedirs(output_dir, exist_ok=True)
                    break
                elif resp in ("continue", "c"):
                    console.print("[blue]Continuing download...[/blue]")
                    break
                else:
                    console.print("[red]Invalid response. Please try again.[/red]")

        # Create download tasks for all files in the manifest
        tasks = [
            DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
            for fname, url in entry.files.items()
        ]

        # Initialize and run parallel downloader
        downloader = ParallelDownloader(
            max_concurrent_downloads=max_concurrent_downloads
        )
        asyncio.run(downloader.download_all(tasks))
 | |
| 
 | |
| 
 | |
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Handle `llama download`.

    With ``--manifest-file``, every model in the manifest is fetched. Otherwise
    each id in the comma-separated ``--model-id`` is resolved and downloaded
    from ``--source``. Any failure is reported via ``parser.error``.
    """
    try:
        if args.manifest_file:
            _download_from_manifest(args.manifest_file, args.max_parallel)
            return

        if args.model_id is None:
            parser.error("Please provide a model id")
            return

        from llama_models.sku_list import llama_meta_net_info, resolve_model

        from .model.safety_models import (
            prompt_guard_download_info,
            prompt_guard_model_sku,
        )

        prompt_guard = prompt_guard_model_sku()

        def _lookup(wanted: str):
            # Prompt Guard comes from the local safety_models helpers rather
            # than from resolve_model.
            if wanted == prompt_guard.model_id:
                return prompt_guard, prompt_guard_download_info()
            resolved = resolve_model(wanted)
            if resolved is None:
                parser.error(f"Model {wanted} not found")
                return None
            return resolved, llama_meta_net_info(resolved)

        requested_ids = [part.strip() for part in args.model_id.split(",")]
        for model_id in requested_ids:
            found = _lookup(model_id)
            if found is None:
                continue
            model, info = found

            if args.source == "huggingface":
                _hf_download(model, args.hf_token, args.ignore_patterns, parser)
                continue

            meta_url = args.meta_url or input(
                f"Please provide the signed URL for model {model_id} you received via email "
                f"after visiting https://www.llama.com/llama-downloads/ "
                f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
            )
            if "llamameta.net" not in meta_url:
                parser.error("Invalid Meta URL provided")
            _meta_download(model, model_id, meta_url, info, args.max_parallel)

    except Exception as e:
        parser.error(f"Download failed: {str(e)}")
 |