Added optional md5 validate command once download is completed (#486)

# What does this PR do?

Adds description at the end of successful download the optionally run
the verify md5 checksums command.

## Test Plan
<img width="2004" alt="Screenshot 2024-11-19 at 12 11 37 PM"
src="https://github.com/user-attachments/assets/8d617aef-99f5-4c3b-b93c-eff3e68289ea">

## Before submitting

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

---------

Co-authored-by: varunfb <vontimitta@devgpu004.eag5.facebook.com>
This commit is contained in:
varunfb 2024-11-19 17:42:43 -08:00 committed by GitHub
parent e670f99ef7
commit 08be023290
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -380,6 +380,7 @@ def _hf_download(
def _meta_download(
model: "Model",
model_id: str,
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
@ -405,8 +406,15 @@ def _meta_download(
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
print(f"\nSuccessfully downloaded model to {output_dir}")
cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
cprint(
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
"white",
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
"yellow",
)
class ModelEntry(BaseModel):
@ -512,7 +520,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
)
if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided")
_meta_download(model, meta_url, info, args.max_parallel)
_meta_download(model, model_id, meta_url, info, args.max_parallel)
except Exception as e:
parser.error(f"Download failed: {str(e)}")