add --max-parallel option

This commit is contained in:
Ashwin Bharambe 2024-11-13 14:03:17 -08:00
parent a98dca12a9
commit 3b61c31dab

View file

@ -73,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
required=False,
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
)
parser.add_argument(
"--max-parallel",
type=int,
required=False,
default=3,
help="Maximum number of concurrent downloads",
)
parser.add_argument(
"--ignore-patterns",
type=str,
@ -381,7 +388,12 @@ def _hf_download(
print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
def _meta_download(
model: "Model",
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
):
"""Download model files from Meta using parallel downloader"""
from llama_stack.distribution.utils.model_utils import model_local_dir
@ -401,7 +413,7 @@ def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
)
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=3)
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
print(f"\nSuccessfully downloaded model to {output_dir}")
@ -421,7 +433,7 @@ class Manifest(BaseModel):
expires_on: datetime
def _download_from_manifest(manifest_file: str):
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
"""Download files from manifest using parallel downloader"""
from llama_stack.distribution.utils.model_utils import model_local_dir
@ -464,7 +476,9 @@ def _download_from_manifest(manifest_file: str):
]
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=3)
downloader = ParallelDownloader(
max_concurrent_downloads=max_concurrent_downloads
)
asyncio.run(downloader.download_all(tasks))
@ -472,7 +486,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Main download command handler"""
try:
if args.manifest_file:
_download_from_manifest(args.manifest_file)
_download_from_manifest(args.manifest_file, args.max_parallel)
return
if args.model_id is None:
@ -511,7 +525,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
)
if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided")
_meta_download(model, meta_url, info)
_meta_download(model, meta_url, info, args.max_parallel)
except Exception as e:
parser.error(f"Download failed: {str(e)}")