Add llama download support for multiple models with comma-separated list

This commit is contained in:
ABucket 2024-10-12 14:28:36 +08:00
parent a2b87ed0cb
commit 8fcded0004

View file

@ -48,6 +48,11 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
required=False,
help="See `llama model list` or `llama model list --show-all` for the list of available models",
)
parser.add_argument(
"--model-ids",
required=False,
help="Comma-separated list of model IDs to download. See `llama model list` or `llama model list --show-all` for the list of available models",
)
parser.add_argument(
"--hf-token",
type=str,
@ -148,28 +153,40 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
_download_from_manifest(args.manifest_file)
return
if args.model_id is None:
parser.error("Please provide a model id")
if args.model_ids:
model_ids = [model_id.strip() for model_id in args.model_ids.split(",")]
elif args.model_id:
model_ids = [args.model_id]
else:
parser.error("Please provide a model id or a list of model ids (--model-ids)")
return
meta_urls = []
if args.meta_url:
meta_urls = [url.strip() for url in args.meta_url.split(",")]
if len(meta_urls) > 0 and len(meta_urls) != len(model_ids):
parser.error("The number of --meta-url values must match the number of --model-ids values.")
prompt_guard = prompt_guard_model_sku()
if args.model_id == prompt_guard.model_id:
for idx, model_id in enumerate(model_ids):
if model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(args.model_id)
model = resolve_model(model_id)
if model is None:
parser.error(f"Model {args.model_id} not found")
parser.error(f"Model {model_id} not found")
return
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url
if not meta_url:
if len(meta_urls) > idx:
meta_url = meta_urls[idx]
else:
meta_url = input(
"Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
f"Please provide the signed URL for {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
assert meta_url is not None and "llamameta.net" in meta_url
_meta_download(model, meta_url, info)