mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
local imports for faster cli
This commit is contained in:
parent
af4710c959
commit
67229f23a4
9 changed files with 44 additions and 47 deletions
|
@ -13,15 +13,6 @@ from pathlib import Path
|
|||
|
||||
import httpx
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import (
|
||||
all_registered_models,
|
||||
llama_meta_net_info,
|
||||
resolve_model,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
|
@ -46,6 +37,8 @@ class Download(Subcommand):
|
|||
self.parser.set_defaults(func=self._run_download_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
models = all_registered_models()
|
||||
self.parser.add_argument(
|
||||
"--source",
|
||||
|
@ -81,7 +74,10 @@ safetensors files to avoid downloading duplicate weights.
|
|||
""",
|
||||
)
|
||||
|
||||
def _hf_download(self, model: Model, hf_token: str, ignore_patterns: str):
|
||||
def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
|
@ -112,7 +108,9 @@ safetensors files to avoid downloading duplicate weights.
|
|||
|
||||
print(f"Successfully downloaded model to {true_output_dir}")
|
||||
|
||||
def _meta_download(self, model: Model, meta_url: str):
|
||||
def _meta_download(self, model: "Model", meta_url: str):
|
||||
from llama_models.sku_list import llama_meta_net_info
|
||||
|
||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
@ -128,6 +126,8 @@ safetensors files to avoid downloading duplicate weights.
|
|||
asyncio.run(downloader.download())
|
||||
|
||||
def _run_download_cmd(self, args: argparse.Namespace):
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
model = resolve_model(args.model_id)
|
||||
if model is None:
|
||||
self.parser.error(f"Model {args.model_id} not found")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue