forked from phoenix-oss/llama-stack-mirror
Add a verify-download command to llama CLI (#457)
# What does this PR do? It is important to verify large checkpoints downloaded via `llama model download` because subtle corruptions can easily happen with large file system writes. This PR adds a `verify-download` subcommand. Note that verification itself is a very time-consuming process (and will take several **minutes** for the 405B model), hence this is a separate subcommand (and not part of the download which can already be time-consuming) and there are spinners and a bit of a "show" around it in the implementation. ## Test Plan <img width="1012" alt="image" src="https://github.com/user-attachments/assets/f82b0d42-2a15-4917-b85e-6d3cd7d31e55">
This commit is contained in:
parent
0713607b68
commit
acbecbf8b3
4 changed files with 173 additions and 1 deletions
144
llama_stack/cli/verify_download.py
Normal file
144
llama_stack/cli/verify_download.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
@dataclass
class VerificationResult:
    """Outcome of checking one checklist entry against the file on disk."""

    # Path of the file as listed in the checklist.
    filename: str
    # MD5 digest recorded in the checklist for this file.
    expected_hash: str
    # MD5 digest computed from disk; None when the file was not hashed.
    actual_hash: Optional[str]
    # Whether the file was found on disk.
    exists: bool
    # Whether the computed digest equals the expected one.
    matches: bool
|
||||
|
||||
|
||||
class VerifyDownload(Subcommand):
    """CLI subcommand that registers ``llama verify-download``."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Create the ``verify-download`` parser and attach its arguments.

        Args:
            subparsers: The sub-parser collection the command is added to.
        """
        super().__init__()

        parser_options = {
            "prog": "llama verify-download",
            "description": "Verify integrity of downloaded model files",
            "formatter_class": argparse.RawTextHelpFormatter,
        }
        self.parser = subparsers.add_parser("verify-download", **parser_options)

        # Argument definitions and the command handler are wired up by the
        # module-level helper.
        setup_verify_download_parser(self.parser)
|
||||
|
||||
|
||||
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
    """Declare the command's arguments and bind its handler.

    Args:
        parser: The ``verify-download`` argument parser to configure.
    """
    parser.add_argument("--model-id", required=True, help="Model ID to verify")

    # The handler receives the parser as well so it can report usage errors
    # through ``parser.error``.
    handler = partial(run_verify_cmd, parser=parser)
    parser.set_defaults(func=handler)
|
||||
|
||||
|
||||
def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
    """Return the hexadecimal MD5 digest of the file at *filepath*.

    The file is read in pieces of ``chunk_size`` bytes so that its whole
    content never has to be held in memory at once.

    Args:
        filepath: File to hash.
        chunk_size: Number of bytes requested per read.

    Returns:
        The MD5 digest as a hex string.
    """
    digest = hashlib.md5()
    with open(filepath, "rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if block == b"":
                # An empty read marks end of file.
                break
            digest.update(block)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def load_checksums(checklist_path: Path) -> Dict[str, str]:
    """Parse a checklist file into a ``{relative path: md5 hex digest}`` map.

    Each non-blank line is expected to hold a digest, whitespace, and a file
    path (the layout written by ``md5sum``). Blank lines are ignored.

    Args:
        checklist_path: Location of the checklist file.

    Returns:
        Mapping from file path (without a leading ``./``) to expected digest.

    Raises:
        ValueError: If a non-blank line has no path after the digest.
    """
    checksums: Dict[str, str] = {}
    with open(checklist_path, "r", encoding="utf-8") as f:
        for line_number, raw_line in enumerate(f, start=1):
            entry = raw_line.strip()
            if not entry:
                continue

            # Split on any run of whitespace: md5sum separates the digest and
            # the path with two characters, and splitting on a single space
            # would leave a stray leading space on the path.
            fields = entry.split(None, 1)
            if len(fields) != 2:
                raise ValueError(
                    f"Malformed checklist entry at {checklist_path}:{line_number}: "
                    f"{entry!r}"
                )
            md5sum, filepath = fields

            # Drop only a literal "./" prefix. str.lstrip("./") would strip
            # every leading '.' and '/' character and so corrupt names such
            # as ".hidden/file" or "../file".
            if filepath.startswith("./"):
                filepath = filepath[2:]

            checksums[filepath] = md5sum
    return checksums
|
||||
|
||||
|
||||
def verify_files(
    model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
    """Check every checklist entry against the files under *model_dir*.

    A spinner line is shown on *console* for the file currently being
    processed and removed once that file is done.

    Args:
        model_dir: Directory the checklist paths are relative to.
        checksums: Mapping from relative file path to expected MD5 digest.
        console: Console the progress display is rendered on.

    Returns:
        One ``VerificationResult`` per checklist entry, in mapping order.
    """
    spinner = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    )
    outcomes: List[VerificationResult] = []

    with spinner as progress:
        for relative_name, wanted_hash in checksums.items():
            target = model_dir / relative_name
            # total=None gives an indeterminate task (no known step count).
            task_id = progress.add_task(f"Verifying {relative_name}...", total=None)

            present = target.exists()
            # Only hash files that are actually there; otherwise leave None.
            computed_hash = calculate_md5(target) if present else None
            same = present and computed_hash == wanted_hash

            outcomes.append(
                VerificationResult(
                    filename=relative_name,
                    expected_hash=wanted_hash,
                    actual_hash=computed_hash,
                    exists=present,
                    matches=same,
                )
            )

            progress.remove_task(task_id)

    return outcomes
|
||||
|
||||
|
||||
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Verify the downloaded files of ``args.model_id`` and report the result.

    The model directory is resolved with ``model_local_dir`` and must contain
    a ``checklist.chk`` file. Every entry of that checklist is checked and a
    per-file line is printed, followed by an overall summary.

    Args:
        args: Parsed command-line arguments; ``model_id`` is read.
        parser: Parser used to report a missing directory or checklist via
            ``parser.error``.

    Raises:
        SystemExit: Through ``parser.error`` when the model directory or the
            checklist is missing, and with status 1 when at least one file is
            missing or has a mismatching hash, so that scripts invoking the
            command can detect a failed verification.
    """
    # Imported here rather than at module import time, as in the original.
    from llama_stack.distribution.utils.model_utils import model_local_dir

    console = Console()
    model_dir = Path(model_local_dir(args.model_id))
    checklist_path = model_dir / "checklist.chk"

    if not model_dir.exists():
        parser.error(f"Model directory not found: {model_dir}")

    if not checklist_path.exists():
        parser.error(f"Checklist file not found: {checklist_path}")

    checksums = load_checksums(checklist_path)
    results = verify_files(model_dir, checksums, console)

    # Print results
    console.print("\nVerification Results:")

    all_good = True
    for result in results:
        if not result.exists:
            console.print(f"[red]❌ {result.filename}: File not found[/red]")
            all_good = False
        elif not result.matches:
            console.print(
                f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
                f" Expected: {result.expected_hash}\n"
                f" Got: {result.actual_hash}"
            )
            all_good = False
        else:
            console.print(f"[green]✓ {result.filename}: Verified[/green]")

    if all_good:
        console.print("\n[green]All files verified successfully![/green]")
        return

    # A failed verification must not look like success to the calling shell:
    # say so explicitly and exit with a non-zero status.
    console.print("\n[red]Verification failed: some files are missing or corrupted.[/red]")
    raise SystemExit(1)
|
Loading…
Add table
Add a link
Reference in a new issue