Added optional md5 validate command once download is completed

This commit is contained in:
varunfb 2024-11-19 12:05:34 -08:00 committed by varunfb
parent 394519d68a
commit 42acff502c

View file

@ -19,6 +19,8 @@ import httpx
from llama_models.datatypes import Model from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo from llama_models.sku_list import LlamaDownloadInfo
from llama_stack.cli.subcommand import Subcommand
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from rich.console import Console from rich.console import Console
@ -32,8 +34,6 @@ from rich.progress import (
) )
from termcolor import cprint from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
class Download(Subcommand): class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets""" """Llama cli for downloading llama toolchain assets"""
@ -380,6 +380,7 @@ def _hf_download(
def _meta_download( def _meta_download(
model: "Model", model: "Model",
model_id: str,
meta_url: str, meta_url: str,
info: "LlamaDownloadInfo", info: "LlamaDownloadInfo",
max_concurrent_downloads: int, max_concurrent_downloads: int,
@ -405,8 +406,15 @@ def _meta_download(
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads) downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks)) asyncio.run(downloader.download_all(tasks))
print(f"\nSuccessfully downloaded model to {output_dir}") cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white") cprint(
f"\nView MD5 checksum files at following location: {output_dir / 'checklist.chk'}",
"white",
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
"yellow",
)
class ModelEntry(BaseModel): class ModelEntry(BaseModel):
@ -512,7 +520,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
) )
if "llamameta.net" not in meta_url: if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided") parser.error("Invalid Meta URL provided")
_meta_download(model, meta_url, info, args.max_parallel) _meta_download(model, model_id, meta_url, info, args.max_parallel)
except Exception as e: except Exception as e:
parser.error(f"Download failed: {str(e)}") parser.error(f"Download failed: {str(e)}")