Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)
refactor: support downloading any model from HF
Given the work being done to support non-Llama models, the download utility should be able to take any `hf_repo/model` identifier and download that model from HF. While such a model might not be usable in llama-stack directly quite yet, it's helpful to have a utility that can download any and all models.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent 19ae4b35d9
commit 6608c7fed9
1 changed file with 31 additions and 11 deletions
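In practice, the change means the download command accepts either a known llama-stack model descriptor or an arbitrary Hugging Face `owner/repo` id: something like `llama download --source huggingface --model-id Llama3.2-3B-Instruct` for a registered model, versus `llama download --source huggingface --model-id ibm-granite/granite-7b-base` for a repo llama-stack does not know about (the exact CLI entry point and the granite repo here are illustrative assumptions, not taken from this commit).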
@@ -311,25 +311,33 @@ class ParallelDownloader:
 def _hf_download(
-    model: "Model",
     hf_token: str,
     ignore_patterns: str,
     parser: argparse.ArgumentParser,
+    llama_model: Optional["Model"] = None,
+    hf_repo: Optional[str] = None,
 ):
     from huggingface_hub import snapshot_download
     from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 
     from llama_stack.distribution.utils.model_utils import model_local_dir
 
-    repo_id = model.huggingface_repo
-    if repo_id is None:
-        raise ValueError(f"No repo id found for model {model.descriptor()}")
+    # if we have a llama_model, meaning the model was found in the
+    # llama_models impl, replace the hf_repo (None) with its value
+    if llama_model:
+        hf_repo = llama_model.huggingface_repo
+    # if the user passed neither a valid hf_repo nor a llama_model, error
+    if hf_repo is None:
+        raise ValueError(f"No repo id found for model {llama_model.descriptor()}")
 
-    output_dir = model_local_dir(model.descriptor())
+    if llama_model:
+        output_dir = model_local_dir(llama_model.descriptor())
+    else:
+        output_dir = model_local_dir(hf_repo)
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(
-            repo_id,
+            hf_repo,
             local_dir=output_dir,
             ignore_patterns=ignore_patterns,
             token=hf_token,
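To see the new resolution logic in isolation: a registered llama model, when present, overrides any raw repo id, and the output directory is keyed off whichever identifier was used. A minimal standalone sketch of that behavior (this is not the commit's code; the `~/.llama/checkpoints` layout and the helper name are assumptions, while `snapshot_download` is the real huggingface_hub API):

from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download


def download_any_hf_repo(
    hf_repo: Optional[str] = None,
    llama_repo: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> str:
    # a model found in the llama_models registry wins over a raw repo id,
    # mirroring the llama_model/hf_repo precedence in _hf_download above
    repo_id = llama_repo or hf_repo
    if repo_id is None:
        raise ValueError("either a registered llama model or an explicit HF repo is required")
    # assumed local layout; llama-stack derives this via model_local_dir()
    output_dir = Path.home() / ".llama" / "checkpoints" / repo_id.replace("/", "--")
    output_dir.mkdir(parents=True, exist_ok=True)
    # snapshot_download returns the directory that actually holds the files
    return snapshot_download(repo_id, local_dir=str(output_dir), token=hf_token)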
@@ -343,7 +351,7 @@ def _hf_download(
             "You can find your token by visiting https://huggingface.co/settings/tokens"
         )
     except RepositoryNotFoundError:
-        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
+        parser.error(f"Repository '{hf_repo}' not found on the Hugging Face Hub.")
     except Exception as e:
         parser.error(e)
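The two except clauses distinguish a repo that exists but is gated behind a license or token from one that cannot be found at all. A sketch of the same pattern outside of argparse (the exception classes are the real ones imported in the diff; the messages are illustrative):

from typing import Optional

from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError


def try_snapshot(hf_repo: str, hf_token: Optional[str] = None) -> Optional[str]:
    try:
        return snapshot_download(hf_repo, token=hf_token)
    except GatedRepoError:
        # the repo exists but requires an accepted license and a valid token
        print("You can find your token by visiting https://huggingface.co/settings/tokens")
    except RepositoryNotFoundError:
        # wrong id, a private repo without access, or simply absent
        print(f"Repository '{hf_repo}' not found on the Hugging Face Hub.")
    return None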
@@ -461,19 +469,31 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     )
 
     prompt_guard = prompt_guard_model_sku()
 
+    # for each model given, get the llama_models.Model;
+    # alternatively, if that search turns up nothing and the source is HF,
+    # pass the repo directly to _hf_download
     for model_id in model_ids:
+        hf_repo = None
         if model_id == prompt_guard.model_id:
             model = prompt_guard
             info = prompt_guard_download_info()
         else:
             model = resolve_model(model_id)
             if model is None:
-                parser.error(f"Model {model_id} not found")
-                continue
-            info = llama_meta_net_info(model)
+                if args.source == "huggingface":
+                    # try passing the id in directly as an HF repo; set
+                    # ignore_patterns to None, since for many smaller models
+                    # you want the .safetensors files
+                    args.ignore_patterns = None
+                    hf_repo = model_id
+                else:
+                    parser.error(f"Model {model_id} not found")
+                    continue
+            else:
+                info = llama_meta_net_info(model)
 
         if args.source == "huggingface":
-            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+            _hf_download(args.hf_token, args.ignore_patterns, parser, llama_model=model, hf_repo=hf_repo)
         else:
             meta_url = args.meta_url or input(
                 f"Please provide the signed URL for model {model_id} you received via email "
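Summarizing the dispatch in run_download_cmd: resolve the id against the known model list first, and fall back to treating it as a raw HF repo only when the source is huggingface. A compressed, self-contained sketch of that fallback (the dict is a stand-in for resolve_model's registry lookup, and its contents are illustrative):

KNOWN_MODELS = {"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct"}  # illustrative stub


def pick_repo(model_id: str, source: str) -> str:
    repo = KNOWN_MODELS.get(model_id)
    if repo is None:
        if source != "huggingface":
            # the old behavior: unknown ids are an error for non-HF sources
            raise SystemExit(f"Model {model_id} not found")
        # the new behavior: treat the id itself as a Hugging Face repo
        repo = model_id
    return repo


# pick_repo("mistralai/Mistral-7B-v0.1", "huggingface") -> "mistralai/Mistral-7B-v0.1"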