forked from phoenix-oss/llama-stack-mirror
		
	Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
		
			
				
	
	
		
			145 lines
		
	
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			145 lines
		
	
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import argparse
 | |
| import hashlib
 | |
| from dataclasses import dataclass
 | |
| from functools import partial
 | |
| from pathlib import Path
 | |
| from typing import Dict, List, Optional
 | |
| 
 | |
| from rich.console import Console
 | |
| from rich.progress import Progress, SpinnerColumn, TextColumn
 | |
| 
 | |
| from llama_stack.cli.subcommand import Subcommand
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class VerificationResult:
 | |
|     filename: str
 | |
|     expected_hash: str
 | |
|     actual_hash: Optional[str]
 | |
|     exists: bool
 | |
|     matches: bool
 | |
| 
 | |
| 
 | |
| class VerifyDownload(Subcommand):
 | |
|     """Llama cli for verifying downloaded model files"""
 | |
| 
 | |
|     def __init__(self, subparsers: argparse._SubParsersAction):
 | |
|         super().__init__()
 | |
|         self.parser = subparsers.add_parser(
 | |
|             "verify-download",
 | |
|             prog="llama verify-download",
 | |
|             description="Verify integrity of downloaded model files",
 | |
|             formatter_class=argparse.RawTextHelpFormatter,
 | |
|         )
 | |
|         setup_verify_download_parser(self.parser)
 | |
| 
 | |
| 
 | |
| def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
 | |
|     parser.add_argument(
 | |
|         "--model-id",
 | |
|         required=True,
 | |
|         help="Model ID to verify",
 | |
|     )
 | |
|     parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
 | |
| 
 | |
| 
 | |
| def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
 | |
|     # NOTE: MD5 is used here only for download integrity verification,
 | |
|     # not for security purposes
 | |
|     # TODO: switch to SHA256
 | |
|     md5_hash = hashlib.md5(usedforsecurity=False)
 | |
|     with open(filepath, "rb") as f:
 | |
|         for chunk in iter(lambda: f.read(chunk_size), b""):
 | |
|             md5_hash.update(chunk)
 | |
|     return md5_hash.hexdigest()
 | |
| 
 | |
| 
 | |
| def load_checksums(checklist_path: Path) -> Dict[str, str]:
 | |
|     checksums = {}
 | |
|     with open(checklist_path, "r") as f:
 | |
|         for line in f:
 | |
|             if line.strip():
 | |
|                 md5sum, filepath = line.strip().split("  ", 1)
 | |
|                 # Remove leading './' if present
 | |
|                 filepath = filepath.lstrip("./")
 | |
|                 checksums[filepath] = md5sum
 | |
|     return checksums
 | |
| 
 | |
| 
 | |
| def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
 | |
|     results = []
 | |
| 
 | |
|     with Progress(
 | |
|         SpinnerColumn(),
 | |
|         TextColumn("[progress.description]{task.description}"),
 | |
|         console=console,
 | |
|     ) as progress:
 | |
|         for filepath, expected_hash in checksums.items():
 | |
|             full_path = model_dir / filepath
 | |
|             task_id = progress.add_task(f"Verifying {filepath}...", total=None)
 | |
| 
 | |
|             exists = full_path.exists()
 | |
|             actual_hash = None
 | |
|             matches = False
 | |
| 
 | |
|             if exists:
 | |
|                 actual_hash = calculate_md5(full_path)
 | |
|                 matches = actual_hash == expected_hash
 | |
| 
 | |
|             results.append(
 | |
|                 VerificationResult(
 | |
|                     filename=filepath,
 | |
|                     expected_hash=expected_hash,
 | |
|                     actual_hash=actual_hash,
 | |
|                     exists=exists,
 | |
|                     matches=matches,
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|             progress.remove_task(task_id)
 | |
| 
 | |
|     return results
 | |
| 
 | |
| 
 | |
| def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
 | |
|     from llama_stack.distribution.utils.model_utils import model_local_dir
 | |
| 
 | |
|     console = Console()
 | |
|     model_dir = Path(model_local_dir(args.model_id))
 | |
|     checklist_path = model_dir / "checklist.chk"
 | |
| 
 | |
|     if not model_dir.exists():
 | |
|         parser.error(f"Model directory not found: {model_dir}")
 | |
| 
 | |
|     if not checklist_path.exists():
 | |
|         parser.error(f"Checklist file not found: {checklist_path}")
 | |
| 
 | |
|     checksums = load_checksums(checklist_path)
 | |
|     results = verify_files(model_dir, checksums, console)
 | |
| 
 | |
|     # Print results
 | |
|     console.print("\nVerification Results:")
 | |
| 
 | |
|     all_good = True
 | |
|     for result in results:
 | |
|         if not result.exists:
 | |
|             console.print(f"[red]❌ {result.filename}: File not found[/red]")
 | |
|             all_good = False
 | |
|         elif not result.matches:
 | |
|             console.print(
 | |
|                 f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
 | |
|                 f"   Expected: {result.expected_hash}\n"
 | |
|                 f"   Got:      {result.actual_hash}"
 | |
|             )
 | |
|             all_good = False
 | |
|         else:
 | |
|             console.print(f"[green]✓ {result.filename}: Verified[/green]")
 | |
| 
 | |
|     if all_good:
 | |
|         console.print("\n[green]All files verified successfully![/green]")
 |