mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
add --max-parallel option
This commit is contained in:
parent
a98dca12a9
commit
3b61c31dab
1 changed files with 20 additions and 6 deletions
|
@ -73,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
|||
required=False,
|
||||
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
required=False,
|
||||
default=3,
|
||||
help="Maximum number of concurrent downloads",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore-patterns",
|
||||
type=str,
|
||||
|
@ -381,7 +388,12 @@ def _hf_download(
|
|||
print(f"\nSuccessfully downloaded model to {true_output_dir}")
|
||||
|
||||
|
||||
def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
|
||||
def _meta_download(
|
||||
model: "Model",
|
||||
meta_url: str,
|
||||
info: "LlamaDownloadInfo",
|
||||
max_concurrent_downloads: int,
|
||||
):
|
||||
"""Download model files from Meta using parallel downloader"""
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
|
@ -401,7 +413,7 @@ def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
|
|||
)
|
||||
|
||||
# Initialize and run parallel downloader
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=3)
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||
asyncio.run(downloader.download_all(tasks))
|
||||
|
||||
print(f"\nSuccessfully downloaded model to {output_dir}")
|
||||
|
@ -421,7 +433,7 @@ class Manifest(BaseModel):
|
|||
expires_on: datetime
|
||||
|
||||
|
||||
def _download_from_manifest(manifest_file: str):
|
||||
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
||||
"""Download files from manifest using parallel downloader"""
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
|
@ -464,7 +476,9 @@ def _download_from_manifest(manifest_file: str):
|
|||
]
|
||||
|
||||
# Initialize and run parallel downloader
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=3)
|
||||
downloader = ParallelDownloader(
|
||||
max_concurrent_downloads=max_concurrent_downloads
|
||||
)
|
||||
asyncio.run(downloader.download_all(tasks))
|
||||
|
||||
|
||||
|
@ -472,7 +486,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
|||
"""Main download command handler"""
|
||||
try:
|
||||
if args.manifest_file:
|
||||
_download_from_manifest(args.manifest_file)
|
||||
_download_from_manifest(args.manifest_file, args.max_parallel)
|
||||
return
|
||||
|
||||
if args.model_id is None:
|
||||
|
@ -511,7 +525,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
|||
)
|
||||
if "llamameta.net" not in meta_url:
|
||||
parser.error("Invalid Meta URL provided")
|
||||
_meta_download(model, meta_url, info)
|
||||
_meta_download(model, meta_url, info, args.max_parallel)
|
||||
|
||||
except Exception as e:
|
||||
parser.error(f"Download failed: {str(e)}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue