cli updates

This commit is contained in:
Hardik Shah 2024-07-21 01:51:54 -07:00
parent 23fe353e4a
commit c9f33d8f68
5 changed files with 14 additions and 9 deletions

View file

@ -7,9 +7,10 @@ from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from toolchain.cli.subcommand import Subcommand
from toolchain.utils import DEFAULT_DUMP_DIR
DEFAULT_CHECKPOINT_DIR = f"{os.path.expanduser('~')}/.llama/checkpoints/"
DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
class Download(Subcommand):
@ -61,10 +62,8 @@ class Download(Subcommand):
def _run_download_cmd(self, args: argparse.Namespace):
model_name = args.repo_id.split("/")[-1]
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
os.makedirs(output_dir, exist_ok=True)
output_dir = Path(output_dir) / model_name
try:
true_output_dir = snapshot_download(
args.repo_id,