Added optional md5 validate command once download is completed

This commit is contained in:
varunfb 2024-11-19 12:05:34 -08:00 committed by varunfb
parent 394519d68a
commit 42acff502c

View file

@ -19,6 +19,8 @@ import httpx
from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo
from llama_stack.cli.subcommand import Subcommand
from pydantic import BaseModel, ConfigDict
from rich.console import Console
@ -32,8 +34,6 @@ from rich.progress import (
)
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""
@ -380,6 +380,7 @@ def _hf_download(
def _meta_download(
model: "Model",
model_id: str,
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
@ -405,8 +406,15 @@ def _meta_download(
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
print(f"\nSuccessfully downloaded model to {output_dir}")
cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
cprint(
f"\nView MD5 checksum files at following location: {output_dir / 'checklist.chk'}",
"white",
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
"yellow",
)
class ModelEntry(BaseModel):
@ -512,7 +520,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
)
if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided")
_meta_download(model, meta_url, info, args.max_parallel)
_meta_download(model, model_id, meta_url, info, args.max_parallel)
except Exception as e:
parser.error(f"Download failed: {str(e)}")