Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-08-12 13:00:39 +00:00
refactor: support downloading any model from HF
Given the work being done to support non-Llama models, the download utility should be able to take any `hf_repo/model` identifier and download that model from the Hugging Face Hub. While such a model might not yet be directly usable in Llama Stack, it is helpful to have a utility that can download any and all models.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
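At its core, the change routes an arbitrary repo id straight into `huggingface_hub.snapshot_download`. A minimal standalone sketch of that idea, outside the llama-stack CLI (the helper name `download_any_hf_model`, the repo id, and the output path below are illustrative, not from the commit):

```python
# Minimal sketch: fetch any HF repo without consulting the llama model registry.
# download_any_hf_model is a hypothetical helper; the repo id is illustrative.
import os

from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError


def download_any_hf_model(hf_repo: str, output_dir: str, hf_token: str | None = None) -> str:
    os.makedirs(output_dir, exist_ok=True)
    try:
        # snapshot_download returns the path of the local snapshot
        return snapshot_download(hf_repo, local_dir=output_dir, token=hf_token)
    except GatedRepoError:
        raise RuntimeError(f"Repo '{hf_repo}' is gated; provide a valid HF token.")
    except RepositoryNotFoundError:
        raise RuntimeError(f"Repo '{hf_repo}' not found on the Hugging Face Hub.")


# e.g. download_any_hf_model("ibm-granite/granite-3.0-2b-instruct", "./models/granite")
```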
parent 19ae4b35d9
commit 6608c7fed9
1 changed file with 31 additions and 11 deletions
@@ -311,25 +311,33 @@ class ParallelDownloader:
 def _hf_download(
-    model: "Model",
     hf_token: str,
     ignore_patterns: str,
     parser: argparse.ArgumentParser,
+    llama_model: Optional["Model"] | None = None,
+    hf_repo: Optional[str] | None = None,
 ):
     from huggingface_hub import snapshot_download
     from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

     from llama_stack.distribution.utils.model_utils import model_local_dir

-    repo_id = model.huggingface_repo
-    if repo_id is None:
-        raise ValueError(f"No repo id found for model {model.descriptor()}")
+    # if we have a llama_model, meaning this model was found in the
+    # llama_models impl, replace the hf_repo (None) with this value.
+    if llama_model:
+        hf_repo = llama_model.huggingface_repo

-    output_dir = model_local_dir(model.descriptor())
+    # if the user did not pass a valid hf_repo or llama_model, error.
+    if hf_repo is None:
+        raise ValueError(f"No repo id found for model {llama_model.descriptor()}")
+
+    if llama_model:
+        output_dir = model_local_dir(llama_model.descriptor())
+    else:
+        output_dir = model_local_dir(hf_repo)
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(
-            repo_id,
+            hf_repo,
             local_dir=output_dir,
             ignore_patterns=ignore_patterns,
             token=hf_token,
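The signature change above boils down to a simple resolution order: a registry-resolved `llama_model` always wins, and a caller-supplied `hf_repo` is the fallback. (Note that in the committed code, reaching the `raise` with no `llama_model` would itself fail, since `llama_model.descriptor()` is called on `None`.) A condensed sketch of just that logic, where `FakeModel` is a hypothetical stand-in for the real `Model` type, included only to make the snippet self-contained:

```python
# Condensed sketch of the resolution order introduced in _hf_download.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeModel:  # hypothetical stand-in for the real Model type
    huggingface_repo: Optional[str]


def resolve_repo(llama_model: Optional[FakeModel], hf_repo: Optional[str]) -> str:
    if llama_model:
        # a model found in the llama_models registry carries its own repo id
        hf_repo = llama_model.huggingface_repo
    if hf_repo is None:
        raise ValueError("no repo id available from either argument")
    return hf_repo


assert resolve_repo(FakeModel("meta-llama/Llama-3.1-8B-Instruct"), None) == "meta-llama/Llama-3.1-8B-Instruct"
assert resolve_repo(None, "ibm-granite/granite-3.0-2b-instruct") == "ibm-granite/granite-3.0-2b-instruct"
```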
@@ -343,7 +351,7 @@ def _hf_download(
             "You can find your token by visiting https://huggingface.co/settings/tokens"
         )
     except RepositoryNotFoundError:
-        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
+        parser.error(f"Repository '{hf_repo}' not found on the Hugging Face Hub.")
     except Exception as e:
         parser.error(e)
@@ -461,19 +469,31 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     )

     prompt_guard = prompt_guard_model_sku()

+    # for each model given, get the llama_models.Model;
+    # alternatively, if that search turns up nothing and the source is HF, pass the repo directly to _hf_download
     for model_id in model_ids:
+        hf_repo = None
         if model_id == prompt_guard.model_id:
             model = prompt_guard
             info = prompt_guard_download_info()
         else:
             model = resolve_model(model_id)
             if model is None:
-                parser.error(f"Model {model_id} not found")
-                continue
-            info = llama_meta_net_info(model)
+                if args.source == "huggingface":
+                    # try passing model_id in directly as an HF repo;
+                    # in that case, set ignore_patterns to None, since for
+                    # many smaller models you want the .safetensors files.
+                    args.ignore_patterns = None
+                    hf_repo = model_id
+                else:
+                    parser.error(f"Model {model_id} not found")
+                    continue
+            else:
+                info = llama_meta_net_info(model)

         if args.source == "huggingface":
-            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+            _hf_download(args.hf_token, args.ignore_patterns, parser, llama_model=model, hf_repo=hf_repo)
         else:
             meta_url = args.meta_url or input(
                 f"Please provide the signed URL for model {model_id} you received via email "
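End to end, `run_download_cmd` can now be driven with either a known llama model descriptor or an arbitrary HF repo id. A hedged sketch of the two argument shapes at the Python level (the `Namespace` fields mirror the attributes the diff reads from `args`; the field name `model_id`, the descriptor, and the repo id are assumptions for illustration):

```python
# Sketch: the two kinds of model id run_download_cmd now accepts.
# Fields mirror the attributes the diff reads from args (source, hf_token,
# ignore_patterns, meta_url); model_id and the values are illustrative.
import argparse


def make_args(model_id: str) -> argparse.Namespace:
    return argparse.Namespace(
        source="huggingface",
        model_id=model_id,
        hf_token=None,            # set for gated repos
        ignore_patterns="*.pth",  # reset to None internally for non-registry repos
        meta_url=None,
    )


# Registry path: resolve_model() finds the descriptor, so _hf_download
# receives llama_model=model and derives the repo id from it.
known = make_args("Llama3.1-8B-Instruct")

# Fallback path: resolve_model() returns None, so the raw id is forwarded
# verbatim as hf_repo and ignore_patterns is cleared.
arbitrary = make_args("ibm-granite/granite-3.0-2b-instruct")
```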