diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index a1495cbf0..7363d6b07 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -48,6 +48,11 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None: required=False, help="See `llama model list` or `llama model list --show-all` for the list of available models", ) + parser.add_argument( + "--model-ids", + required=False, + help="Comma-separated list of model IDs to download. See `llama model list` or `llama model list --show-all` for the list of available models", + ) parser.add_argument( "--hf-token", type=str, @@ -148,31 +153,43 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): _download_from_manifest(args.manifest_file) return - if args.model_id is None: - parser.error("Please provide a model id") + if args.model_ids: + model_ids = [model_id.strip() for model_id in args.model_ids.split(",")] + elif args.model_id: + model_ids = [args.model_id] + else: + parser.error("Please provide a model id or a list of model ids (--model-ids)") return - prompt_guard = prompt_guard_model_sku() - if args.model_id == prompt_guard.model_id: - model = prompt_guard - info = prompt_guard_download_info() - else: - model = resolve_model(args.model_id) - if model is None: - parser.error(f"Model {args.model_id} not found") - return - info = llama_meta_net_info(model) + meta_urls = [] + if args.meta_url: + meta_urls = [url.strip() for url in args.meta_url.split(",")] + if len(meta_urls) > 0 and len(meta_urls) != len(model_ids): + parser.error("The number of --meta-url values must match the number of --model-ids values.") - if args.source == "huggingface": - _hf_download(model, args.hf_token, args.ignore_patterns, parser) - else: - meta_url = args.meta_url - if not meta_url: - meta_url = input( - "Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): " - ) + prompt_guard = prompt_guard_model_sku() + for idx, model_id in enumerate(model_ids): + if model_id == prompt_guard.model_id: + model = prompt_guard + info = prompt_guard_download_info() + else: + model = resolve_model(model_id) + if model is None: + parser.error(f"Model {model_id} not found") + return + info = llama_meta_net_info(model) + + if args.source == "huggingface": + _hf_download(model, args.hf_token, args.ignore_patterns, parser) + else: + if len(meta_urls) > idx: + meta_url = meta_urls[idx] + else: + meta_url = input( + f"Please provide the signed URL for {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): " + ) assert meta_url is not None and "llamameta.net" in meta_url - _meta_download(model, meta_url, info) + _meta_download(model, meta_url, info) class ModelEntry(BaseModel):