From 6608c7fed900a0e13e10dd23728a2545f1629132 Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Thu, 6 Feb 2025 19:48:50 -0500
Subject: [PATCH] refactor: support downloading any model from HF

Given the work being done to support non-llama models, the download
utility should be able to take any `hf_repo/model` identifier and
download a qualified model from HF. While such a model might not yet be
usable directly in llama stack, it's helpful to have a utility that can
download any and all models.

Signed-off-by: Charlie Doern
---
 llama_stack/cli/download.py | 42 +++++++++++++++++++++++++++----------
 1 file changed, 31 insertions(+), 11 deletions(-)

diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py
index af86f7243..eb9e99511 100644
--- a/llama_stack/cli/download.py
+++ b/llama_stack/cli/download.py
@@ -311,25 +311,33 @@ class ParallelDownloader:
 
 
 def _hf_download(
-    model: "Model",
     hf_token: str,
     ignore_patterns: str,
     parser: argparse.ArgumentParser,
+    llama_model: Optional["Model"] = None,
+    hf_repo: Optional[str] = None,
 ):
     from huggingface_hub import snapshot_download
     from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 
     from llama_stack.distribution.utils.model_utils import model_local_dir
 
-    repo_id = model.huggingface_repo
-    if repo_id is None:
-        raise ValueError(f"No repo id found for model {model.descriptor()}")
+    # if we have a llama_model, meaning this model was found in the
+    # llama_models impl, replace hf_repo (None) with its repo id.
+    if llama_model:
+        hf_repo = llama_model.huggingface_repo
+    # if the caller passed neither a valid hf_repo nor a llama_model, error.
+    if hf_repo is None:
+        raise ValueError(f"No repo id found for model {llama_model.descriptor() if llama_model else '<unknown>'}")
 
-    output_dir = model_local_dir(model.descriptor())
+    if llama_model:
+        output_dir = model_local_dir(llama_model.descriptor())
+    else:
+        output_dir = model_local_dir(hf_repo)
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(
-            repo_id,
+            hf_repo,
             local_dir=output_dir,
             ignore_patterns=ignore_patterns,
             token=hf_token,
@@ -343,7 +351,7 @@ def _hf_download(
             "You can find your token by visiting https://huggingface.co/settings/tokens"
         )
     except RepositoryNotFoundError:
-        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
+        parser.error(f"Repository '{hf_repo}' not found on the Hugging Face Hub.")
     except Exception as e:
         parser.error(e)
 
@@ -461,19 +469,31 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
         )
 
     prompt_guard = prompt_guard_model_sku()
+
+    # for each model given, get the llama_models.Model
+    # alternatively, if that search turns up nothing and the source is HF, pass the repo directly to _hf_download
     for model_id in model_ids:
+        hf_repo = None
         if model_id == prompt_guard.model_id:
             model = prompt_guard
             info = prompt_guard_download_info()
         else:
             model = resolve_model(model_id)
             if model is None:
-                parser.error(f"Model {model_id} not found")
-                continue
-            info = llama_meta_net_info(model)
+                if args.source == "huggingface":
+                    # try just passing the model_id in as an HF repo.
+                    # in this case, set ignore_patterns to None, since for
+                    # many smaller models you want the .safetensors files.
+                    args.ignore_patterns = None
+                    hf_repo = model_id
+                else:
+                    parser.error(f"Model {model_id} not found")
+                    continue
+            else:
+                info = llama_meta_net_info(model)
 
         if args.source == "huggingface":
-            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+            _hf_download(args.hf_token, args.ignore_patterns, parser, llama_model=model, hf_repo=hf_repo)
         else:
             meta_url = args.meta_url or input(
                 f"Please provide the signed URL for model {model_id} you received via email "
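
For illustration, a minimal runnable sketch of the repo-id precedence this
patch inlines into _hf_download: metadata from a known llama model wins,
otherwise the user-supplied repo id is used as-is. The standalone helper
name resolve_repo_id and the repo ids below are hypothetical, shown only
to make the rule concrete.

    from typing import Optional

    def resolve_repo_id(llama_model_repo: Optional[str], hf_repo: Optional[str]) -> str:
        # A known llama model's huggingface_repo takes precedence; otherwise
        # fall back to the user-supplied repo id, and error if neither exists.
        if llama_model_repo is not None:
            return llama_model_repo
        if hf_repo is None:
            raise ValueError("no Hugging Face repo id available")
        return hf_repo

    # Known llama model: the repo id comes from the model's metadata.
    assert resolve_repo_id("meta-llama/Llama-Guard-3-1B", None) == "meta-llama/Llama-Guard-3-1B"
    # Arbitrary HF model: the raw model_id is passed through unchanged.
    assert resolve_repo_id(None, "some-org/some-model") == "some-org/some-model"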