refactor: support downloading any model from HF

Given the work being done to support non-llama models, the download
utility should be able to take any qualified `hf_repo/model` identifier
and download that model from the Hugging Face Hub. While such a model
might not yet be usable directly in llama-stack, it is helpful to have
a utility that can download any model.
Signed-off-by: Charlie Doern <cdoern@redhat.com>
Author: Charlie Doern <cdoern@redhat.com>
Date:   2025-02-06 19:48:50 -05:00
parent 19ae4b35d9
commit 6608c7fed9
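
As a usage sketch, assuming the existing `llama download` flags `--source` and `--model-id` (the model ids below are illustrative placeholders):

    # a known llama model, resolved through llama_models as before
    llama download --source huggingface --model-id Llama3.1-8B-Instruct

    # any other HF repo, passed through directly (new behavior)
    llama download --source huggingface --model-id org/some-model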

@@ -311,25 +311,33 @@ class ParallelDownloader:
 def _hf_download(
-    model: "Model",
     hf_token: str,
     ignore_patterns: str,
     parser: argparse.ArgumentParser,
+    llama_model: Optional["Model"] = None,
+    hf_repo: Optional[str] = None,
 ):
     from huggingface_hub import snapshot_download
     from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

     from llama_stack.distribution.utils.model_utils import model_local_dir

-    repo_id = model.huggingface_repo
-    if repo_id is None:
-        raise ValueError(f"No repo id found for model {model.descriptor()}")
+    # if we have a llama_model, meaning this model was found in the
+    # llama_models impl, replace the hf_repo (None) with this value.
+    if llama_model:
+        hf_repo = llama_model.huggingface_repo
+    # if the user did not pass a valid hf_repo or llama_model, error.
+    if hf_repo is None:
+        raise ValueError(f"No repo id found for model {llama_model.descriptor()}")

-    output_dir = model_local_dir(model.descriptor())
+    if llama_model:
+        output_dir = model_local_dir(llama_model.descriptor())
+    else:
+        output_dir = model_local_dir(hf_repo)
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(
-            repo_id,
+            hf_repo,
             local_dir=output_dir,
             ignore_patterns=ignore_patterns,
             token=hf_token,
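
To make the new keyword contract concrete, here is a minimal sketch of the two call paths (the names come from the diff itself; the repo string is an illustrative placeholder):

    # path 1: a resolved llama model; hf_repo is derived from the model entry
    _hf_download(args.hf_token, args.ignore_patterns, parser, llama_model=model, hf_repo=None)

    # path 2: an arbitrary HF repo; the output dir is keyed on the repo id itself
    _hf_download(args.hf_token, args.ignore_patterns, parser, llama_model=None, hf_repo="org/some-model")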
@@ -343,7 +351,7 @@ def _hf_download(
             "You can find your token by visiting https://huggingface.co/settings/tokens"
         )
     except RepositoryNotFoundError:
-        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
+        parser.error(f"Repository '{hf_repo}' not found on the Hugging Face Hub.")
     except Exception as e:
         parser.error(e)
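
Outside the CLI, the same gated/not-found behavior can be exercised with `huggingface_hub` directly; a minimal standalone sketch (the repo id and local dir are illustrative placeholders):

    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

    try:
        path = snapshot_download("org/some-model", local_dir="/tmp/some-model", token=None)
        print(f"downloaded to {path}")
    except GatedRepoError:
        print("gated repo: pass a token from https://huggingface.co/settings/tokens")
    except RepositoryNotFoundError:
        print("repo not found on the Hugging Face Hub")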
@@ -461,19 +469,31 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     )

     prompt_guard = prompt_guard_model_sku()
+    # for each model given, get the llama_models.Model
+    # alternatively, if that search turns up nothing and the source is HF,
+    # pass the repo directly to _hf_download
     for model_id in model_ids:
+        hf_repo = None
         if model_id == prompt_guard.model_id:
             model = prompt_guard
             info = prompt_guard_download_info()
         else:
             model = resolve_model(model_id)
             if model is None:
-                parser.error(f"Model {model_id} not found")
-                continue
-            info = llama_meta_net_info(model)
+                if args.source == "huggingface":
+                    # try passing model_id in directly as an HF repo;
+                    # in that case, clear ignore_patterns, since for many
+                    # smaller models you want the .safetensors files.
+                    args.ignore_patterns = None
+                    hf_repo = model_id
+                else:
+                    parser.error(f"Model {model_id} not found")
+                    continue
+            else:
+                info = llama_meta_net_info(model)

         if args.source == "huggingface":
-            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+            _hf_download(args.hf_token, args.ignore_patterns, parser, llama_model=model, hf_repo=hf_repo)
         else:
             meta_url = args.meta_url or input(
                 f"Please provide the signed URL for model {model_id} you received via email "