Add a verify-download command to llama CLI (#457)

# What does this PR do?

It is important to verify large checkpoints downloaded via `llama model
download` because subtle corruptions can easily happen with large file
system writes. This PR adds a `verify-download` subcommand. Note that
verification itself is a very time consuming process (and will take
several **minutes** for the 405B model), hence this is a separate
subcommand (and not part of the download which can already be
time-consuming) and there are spinners and a bit of a "show" around it
in the implementation.

## Test Plan

<img width="1012" alt="image"
src="https://github.com/user-attachments/assets/f82b0d42-2a15-4917-b85e-6d3cd7d31e55">
This commit is contained in:
Ashwin Bharambe 2024-11-14 11:47:51 -08:00 committed by GitHub
parent 0713607b68
commit acbecbf8b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 173 additions and 1 deletions

View file

@ -9,6 +9,7 @@ import argparse
from .download import Download from .download import Download
from .model import ModelParser from .model import ModelParser
from .stack import StackParser from .stack import StackParser
from .verify_download import VerifyDownload
class LlamaCLIParser: class LlamaCLIParser:
@ -27,9 +28,10 @@ class LlamaCLIParser:
subparsers = self.parser.add_subparsers(title="subcommands") subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands # Add sub-commands
Download.create(subparsers)
ModelParser.create(subparsers) ModelParser.create(subparsers)
StackParser.create(subparsers) StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args() return self.parser.parse_args()

View file

@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
@ -32,3 +33,4 @@ class ModelParser(Subcommand):
ModelList.create(subparsers) ModelList.create(subparsers)
ModelPromptFormat.create(subparsers) ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers) ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class ModelVerifyDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama model verify-download",
description="Verify the downloaded checkpoints' checksums",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_stack.cli.verify_download import setup_verify_download_parser
setup_verify_download_parser(self.parser)

View file

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.cli.subcommand import Subcommand
@dataclass
class VerificationResult:
filename: str
expected_hash: str
actual_hash: Optional[str]
exists: bool
matches: bool
class VerifyDownload(Subcommand):
"""Llama cli for verifying downloaded model files"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama verify-download",
description="Verify integrity of downloaded model files",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_verify_download_parser(self.parser)
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-id",
required=True,
help="Model ID to verify",
)
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
md5_hash = hashlib.md5()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest()
def load_checksums(checklist_path: Path) -> Dict[str, str]:
checksums = {}
with open(checklist_path, "r") as f:
for line in f:
if line.strip():
md5sum, filepath = line.strip().split(" ", 1)
# Remove leading './' if present
filepath = filepath.lstrip("./")
checksums[filepath] = md5sum
return checksums
def verify_files(
model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
results = []
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
for filepath, expected_hash in checksums.items():
full_path = model_dir / filepath
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
exists = full_path.exists()
actual_hash = None
matches = False
if exists:
actual_hash = calculate_md5(full_path)
matches = actual_hash == expected_hash
results.append(
VerificationResult(
filename=filepath,
expected_hash=expected_hash,
actual_hash=actual_hash,
exists=exists,
matches=matches,
)
)
progress.remove_task(task_id)
return results
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_stack.distribution.utils.model_utils import model_local_dir
console = Console()
model_dir = Path(model_local_dir(args.model_id))
checklist_path = model_dir / "checklist.chk"
if not model_dir.exists():
parser.error(f"Model directory not found: {model_dir}")
if not checklist_path.exists():
parser.error(f"Checklist file not found: {checklist_path}")
checksums = load_checksums(checklist_path)
results = verify_files(model_dir, checksums, console)
# Print results
console.print("\nVerification Results:")
all_good = True
for result in results:
if not result.exists:
console.print(f"[red]❌ {result.filename}: File not found[/red]")
all_good = False
elif not result.matches:
console.print(
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
f" Expected: {result.expected_hash}\n"
f" Got: {result.actual_hash}"
)
all_good = False
else:
console.print(f"[green]✓ {result.filename}: Verified[/green]")
if all_good:
console.print("\n[green]All files verified successfully![/green]")