add --max-parallel option

This commit is contained in:
Ashwin Bharambe 2024-11-13 14:03:17 -08:00
parent a98dca12a9
commit 3b61c31dab

View file

@ -73,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
required=False, required=False,
help="For source=meta, URL obtained from llama.meta.com after accepting license terms", help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
) )
parser.add_argument(
"--max-parallel",
type=int,
required=False,
default=3,
help="Maximum number of concurrent downloads",
)
parser.add_argument( parser.add_argument(
"--ignore-patterns", "--ignore-patterns",
type=str, type=str,
@ -381,7 +388,12 @@ def _hf_download(
print(f"\nSuccessfully downloaded model to {true_output_dir}") print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"): def _meta_download(
model: "Model",
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
):
"""Download model files from Meta using parallel downloader""" """Download model files from Meta using parallel downloader"""
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
@ -401,7 +413,7 @@ def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
) )
# Initialize and run parallel downloader # Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=3) downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks)) asyncio.run(downloader.download_all(tasks))
print(f"\nSuccessfully downloaded model to {output_dir}") print(f"\nSuccessfully downloaded model to {output_dir}")
@ -421,7 +433,7 @@ class Manifest(BaseModel):
expires_on: datetime expires_on: datetime
def _download_from_manifest(manifest_file: str): def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
"""Download files from manifest using parallel downloader""" """Download files from manifest using parallel downloader"""
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
@ -464,7 +476,9 @@ def _download_from_manifest(manifest_file: str):
] ]
# Initialize and run parallel downloader # Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=3) downloader = ParallelDownloader(
max_concurrent_downloads=max_concurrent_downloads
)
asyncio.run(downloader.download_all(tasks)) asyncio.run(downloader.download_all(tasks))
@ -472,7 +486,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Main download command handler""" """Main download command handler"""
try: try:
if args.manifest_file: if args.manifest_file:
_download_from_manifest(args.manifest_file) _download_from_manifest(args.manifest_file, args.max_parallel)
return return
if args.model_id is None: if args.model_id is None:
@ -511,7 +525,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
) )
if "llamameta.net" not in meta_url: if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided") parser.error("Invalid Meta URL provided")
_meta_download(model, meta_url, info) _meta_download(model, meta_url, info, args.max_parallel)
except Exception as e: except Exception as e:
parser.error(f"Download failed: {str(e)}") parser.error(f"Download failed: {str(e)}")