mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
Add a verify-download command to llama CLI (#457)
# What does this PR do?

It is important to verify large checkpoints downloaded via `llama model download`, because subtle corruption can easily happen during large file-system writes. This PR adds a `verify-download` subcommand. Note that verification itself is very time-consuming (it will take several **minutes** for the 405B model), hence it is a separate subcommand (and not part of the download, which can already be time-consuming), and the implementation shows spinners and a bit of a "show" while it runs.

## Test Plan

<img width="1012" alt="image" src="https://github.com/user-attachments/assets/f82b0d42-2a15-4917-b85e-6d3cd7d31e55">
This commit is contained in:
parent
0713607b68
commit
acbecbf8b3
4 changed files with 173 additions and 1 deletions
|
@ -9,6 +9,7 @@ import argparse
|
||||||
from .download import Download
|
from .download import Download
|
||||||
from .model import ModelParser
|
from .model import ModelParser
|
||||||
from .stack import StackParser
|
from .stack import StackParser
|
||||||
|
from .verify_download import VerifyDownload
|
||||||
|
|
||||||
|
|
||||||
class LlamaCLIParser:
|
class LlamaCLIParser:
|
||||||
|
@ -27,9 +28,10 @@ class LlamaCLIParser:
|
||||||
subparsers = self.parser.add_subparsers(title="subcommands")
|
subparsers = self.parser.add_subparsers(title="subcommands")
|
||||||
|
|
||||||
# Add sub-commands
|
# Add sub-commands
|
||||||
Download.create(subparsers)
|
|
||||||
ModelParser.create(subparsers)
|
ModelParser.create(subparsers)
|
||||||
StackParser.create(subparsers)
|
StackParser.create(subparsers)
|
||||||
|
Download.create(subparsers)
|
||||||
|
VerifyDownload.create(subparsers)
|
||||||
|
|
||||||
def parse_args(self) -> argparse.Namespace:
|
def parse_args(self) -> argparse.Namespace:
|
||||||
return self.parser.parse_args()
|
return self.parser.parse_args()
|
||||||
|
|
|
@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
|
||||||
from llama_stack.cli.model.download import ModelDownload
|
from llama_stack.cli.model.download import ModelDownload
|
||||||
from llama_stack.cli.model.list import ModelList
|
from llama_stack.cli.model.list import ModelList
|
||||||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||||
|
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
@ -32,3 +33,4 @@ class ModelParser(Subcommand):
|
||||||
ModelList.create(subparsers)
|
ModelList.create(subparsers)
|
||||||
ModelPromptFormat.create(subparsers)
|
ModelPromptFormat.create(subparsers)
|
||||||
ModelDescribe.create(subparsers)
|
ModelDescribe.create(subparsers)
|
||||||
|
ModelVerifyDownload.create(subparsers)
|
||||||
|
|
24
llama_stack/cli/model/verify_download.py
Normal file
24
llama_stack/cli/model/verify_download.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
|
class ModelVerifyDownload(Subcommand):
    """Register ``verify-download`` under the ``llama model`` command group.

    The parser's arguments and handler are configured by
    ``setup_verify_download_parser`` from ``llama_stack.cli.verify_download``.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()

        verify_parser = subparsers.add_parser(
            "verify-download",
            prog="llama model verify-download",
            description="Verify the downloaded checkpoints' checksums",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self.parser = verify_parser

        # Imported at call time, exactly where the original imported it.
        from llama_stack.cli.verify_download import setup_verify_download_parser

        setup_verify_download_parser(verify_parser)
|
144
llama_stack/cli/verify_download.py
Normal file
144
llama_stack/cli/verify_download.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import hashlib
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||||
|
|
||||||
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VerificationResult:
    """Outcome of checking a single checklist entry against the file on disk."""

    # Path of the file as it is keyed in the checksum mapping.
    filename: str
    # MD5 hex digest the checklist says the file should have.
    expected_hash: str
    # MD5 hex digest computed from disk, or None when no digest was computed.
    actual_hash: Optional[str]
    # Whether the file was present when it was checked.
    exists: bool
    # Whether the computed digest equals the expected one.
    matches: bool
|
||||||
|
|
||||||
|
|
||||||
|
class VerifyDownload(Subcommand):
    """Top-level ``llama verify-download`` subcommand.

    Creates the ``verify-download`` parser on the given subparsers action and
    hands it to ``setup_verify_download_parser`` for argument/handler setup.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()

        verify_parser = subparsers.add_parser(
            "verify-download",
            prog="llama verify-download",
            description="Verify integrity of downloaded model files",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self.parser = verify_parser
        setup_verify_download_parser(verify_parser)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
    """Attach the verify-download arguments and handler to *parser*.

    Adds the required ``--model-id`` option and sets the parser's ``func``
    default to ``run_verify_cmd`` with *parser* pre-bound, so the handler can
    report errors through the same parser.
    """
    model_id_options = {
        "required": True,
        "help": "Model ID to verify",
    }
    parser.add_argument("--model-id", **model_id_options)

    handler = partial(run_verify_cmd, parser=parser)
    parser.set_defaults(func=handler)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
    """Return the hexadecimal MD5 digest of the file at *filepath*.

    The file is opened in binary mode and consumed ``chunk_size`` bytes at a
    time, so the whole file is never held in memory at once.
    """
    digest = hashlib.md5()
    with open(filepath, "rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                # An empty read marks end of file.
                break
            digest.update(block)
    return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def load_checksums(checklist_path: Path) -> Dict[str, str]:
    """Parse an md5sum-style checklist into a ``{file_path: md5_hex}`` mapping.

    Every non-blank line must contain a checksum followed by whitespace and a
    file path. The separator may be any run of whitespace, so both
    ``<hash> <path>`` and the two-space ``<hash>  <path>`` layout are accepted.

    Args:
        checklist_path: Location of the checklist file to read.

    Returns:
        Mapping from file path (with leading ``./`` components removed) to the
        expected MD5 hex digest.

    Raises:
        ValueError: If a non-blank line does not contain both a checksum and
            a path.
    """
    checksums: Dict[str, str] = {}
    with open(checklist_path, "r") as f:
        for line_number, raw_line in enumerate(f, start=1):
            entry = raw_line.strip()
            if not entry:
                continue

            # Split on the first whitespace run; the path itself may contain
            # spaces, so only one split is performed.
            fields = entry.split(None, 1)
            if len(fields) != 2:
                raise ValueError(
                    f"Malformed checklist entry at {checklist_path}:{line_number}: {entry!r}"
                )
            md5sum, filepath = fields

            # Remove the literal './' prefix only. str.lstrip("./") would strip
            # any leading run of '.' and '/' characters and so corrupt names
            # such as '.hidden' or '../sibling'.
            while filepath.startswith("./"):
                filepath = filepath[2:]

            checksums[filepath] = md5sum
    return checksums
|
||||||
|
|
||||||
|
|
||||||
|
def verify_files(
    model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
    """Check every checklist entry against the files under *model_dir*.

    For each ``path -> expected digest`` pair, a spinner task is shown on
    *console* while the file is examined. A missing path is recorded with
    ``exists=False`` and no digest; otherwise the MD5 digest is computed and
    compared with the expected value.

    Returns:
        One ``VerificationResult`` per entry, in the iteration order of
        *checksums*.
    """
    outcomes = []

    spinner_columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*spinner_columns, console=console) as progress:
        for relative_name, wanted_hash in checksums.items():
            target = model_dir / relative_name
            spinner = progress.add_task(f"Verifying {relative_name}...", total=None)

            present = target.exists()
            # Only hash what is actually on disk; absent files keep a None digest.
            computed_hash = calculate_md5(target) if present else None
            is_match = present and computed_hash == wanted_hash

            outcomes.append(
                VerificationResult(
                    filename=relative_name,
                    expected_hash=wanted_hash,
                    actual_hash=computed_hash,
                    exists=present,
                    matches=is_match,
                )
            )

            # Drop the finished spinner so only the file in flight is displayed.
            progress.remove_task(spinner)

    return outcomes
|
||||||
|
|
||||||
|
|
||||||
|
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Verify a downloaded model's files against its ``checklist.chk``.

    Resolves the model directory from ``args.model_id``, loads the checklist,
    hashes every listed file and prints a per-file report.

    Args:
        args: Parsed CLI arguments; ``args.model_id`` selects the model.
        parser: Parser used to report usage errors via ``parser.error``.

    Raises:
        SystemExit: Via ``parser.error`` when the model directory or the
            checklist file is missing, and with status 1 when any file is
            missing or has a mismatching hash, so that callers and scripts
            can detect a failed verification.
    """
    from llama_stack.distribution.utils.model_utils import model_local_dir

    console = Console()
    model_dir = Path(model_local_dir(args.model_id))
    checklist_path = model_dir / "checklist.chk"

    if not model_dir.exists():
        parser.error(f"Model directory not found: {model_dir}")

    if not checklist_path.exists():
        parser.error(f"Checklist file not found: {checklist_path}")

    checksums = load_checksums(checklist_path)
    results = verify_files(model_dir, checksums, console)

    # Print results
    console.print("\nVerification Results:")

    failures = 0
    for result in results:
        if not result.exists:
            console.print(f"[red]❌ {result.filename}: File not found[/red]")
            failures += 1
        elif not result.matches:
            console.print(
                f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
                f" Expected: {result.expected_hash}\n"
                f" Got: {result.actual_hash}"
            )
            failures += 1
        else:
            console.print(f"[green]✓ {result.filename}: Verified[/green]")

    if failures == 0:
        console.print("\n[green]All files verified successfully![/green]")
        return

    # A failed verification must not look like success to the calling shell.
    console.print(
        f"\n[red]Verification failed for {failures} of {len(results)} file(s).[/red]"
    )
    raise SystemExit(1)
|
Loading…
Add table
Add a link
Reference in a new issue