From 3b61c31dabbd47c7a9af1e793339b8e97cf03b69 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 13 Nov 2024 14:03:17 -0800 Subject: [PATCH] add --max-parallel option --- llama_stack/cli/download.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index 1cd4c8820..f23531f90 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -73,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None: required=False, help="For source=meta, URL obtained from llama.meta.com after accepting license terms", ) + parser.add_argument( + "--max-parallel", + type=int, + required=False, + default=3, + help="Maximum number of concurrent downloads", + ) parser.add_argument( "--ignore-patterns", type=str, @@ -381,7 +388,12 @@ def _hf_download( print(f"\nSuccessfully downloaded model to {true_output_dir}") -def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"): +def _meta_download( + model: "Model", + meta_url: str, + info: "LlamaDownloadInfo", + max_concurrent_downloads: int, +): """Download model files from Meta using parallel downloader""" from llama_stack.distribution.utils.model_utils import model_local_dir @@ -401,7 +413,7 @@ def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"): ) # Initialize and run parallel downloader - downloader = ParallelDownloader(max_concurrent_downloads=3) + downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads) asyncio.run(downloader.download_all(tasks)) print(f"\nSuccessfully downloaded model to {output_dir}") @@ -421,7 +433,7 @@ class Manifest(BaseModel): expires_on: datetime -def _download_from_manifest(manifest_file: str): +def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int): """Download files from manifest using parallel downloader""" from llama_stack.distribution.utils.model_utils import model_local_dir @@ -464,7 +476,9 @@ def _download_from_manifest(manifest_file: str): ] # Initialize and run parallel downloader - downloader = ParallelDownloader(max_concurrent_downloads=3) + downloader = ParallelDownloader( + max_concurrent_downloads=max_concurrent_downloads + ) asyncio.run(downloader.download_all(tasks)) @@ -472,7 +486,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): """Main download command handler""" try: if args.manifest_file: - _download_from_manifest(args.manifest_file) + _download_from_manifest(args.manifest_file, args.max_parallel) return if args.model_id is None: @@ -511,7 +525,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): ) if "llamameta.net" not in meta_url: parser.error("Invalid Meta URL provided") - _meta_download(model, meta_url, info) + _meta_download(model, meta_url, info, args.max_parallel) except Exception as e: parser.error(f"Download failed: {str(e)}")