From 57402c1a196235d561a1997e60355d9a00de22fd Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 7 Aug 2024 16:10:26 -0700 Subject: [PATCH] Add `llama model download` alias for `llama download` --- llama_toolchain/cli/download.py | 71 ++++++++++++++------------- llama_toolchain/cli/model/download.py | 24 +++++++++ llama_toolchain/cli/model/model.py | 15 ++---- 3 files changed, 65 insertions(+), 45 deletions(-) create mode 100644 llama_toolchain/cli/model/download.py diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index 401bc633c..02cbe40a1 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -30,46 +30,47 @@ class Download(Subcommand): description="Download a model from llama.meta.comf or HuggingFace hub", formatter_class=argparse.RawTextHelpFormatter, ) - self._add_arguments() - self.parser.set_defaults(func=partial(run_download_cmd, parser=self.parser)) + setup_download_parser(self.parser) - def _add_arguments(self): - from llama_models.sku_list import all_registered_models - models = all_registered_models() - self.parser.add_argument( - "--source", - choices=["meta", "huggingface"], - required=True, - ) - self.parser.add_argument( - "--model-id", - choices=[x.descriptor() for x in models], - required=True, - ) - self.parser.add_argument( - "--hf-token", - type=str, - required=False, - default=None, - help="Hugging Face API token. Needed for gated models like llama2/3. 
Will also try to read environment variable `HF_TOKEN` as default.", - ) - self.parser.add_argument( - "--meta-url", - type=str, - required=False, - help="For source=meta, URL obtained from llama.meta.com after accepting license terms", - ) - self.parser.add_argument( - "--ignore-patterns", - type=str, - required=False, - default="*.safetensors", - help=""" +def setup_download_parser(parser: argparse.ArgumentParser) -> None: + from llama_models.sku_list import all_registered_models + + models = all_registered_models() + parser.add_argument( + "--source", + choices=["meta", "huggingface"], + required=True, + ) + parser.add_argument( + "--model-id", + choices=[x.descriptor() for x in models], + required=True, + ) + parser.add_argument( + "--hf-token", + type=str, + required=False, + default=None, + help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.", + ) + parser.add_argument( + "--meta-url", + type=str, + required=False, + help="For source=meta, URL obtained from llama.meta.com after accepting license terms", + ) + parser.add_argument( + "--ignore-patterns", + type=str, + required=False, + default="*.safetensors", + help=""" For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring safetensors files to avoid downloading duplicate weights. """, - ) + ) + parser.set_defaults(func=partial(run_download_cmd, parser=parser)) def _hf_download( diff --git a/llama_toolchain/cli/model/download.py b/llama_toolchain/cli/model/download.py new file mode 100644 index 000000000..f133c1c6c --- /dev/null +++ b/llama_toolchain/cli/model/download.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import argparse + +from llama_toolchain.cli.subcommand import Subcommand + + +class ModelDownload(Subcommand): +    def __init__(self, subparsers: argparse._SubParsersAction): +        super().__init__() +        self.parser = subparsers.add_parser( +            "download", +            prog="llama model download", +            description="Download a model from llama.meta.com or HuggingFace hub", +            formatter_class=argparse.RawTextHelpFormatter, +        ) + +        from llama_toolchain.cli.download import setup_download_parser + +        setup_download_parser(self.parser) diff --git a/llama_toolchain/cli/model/model.py b/llama_toolchain/cli/model/model.py index 34cec3a67..9a14450ad 100644 --- a/llama_toolchain/cli/model/model.py +++ b/llama_toolchain/cli/model/model.py @@ -5,9 +5,9 @@ # the root directory of this source tree. import argparse -import textwrap from llama_toolchain.cli.model.describe import ModelDescribe +from llama_toolchain.cli.model.download import ModelDownload from llama_toolchain.cli.model.list import ModelList from llama_toolchain.cli.model.template import ModelTemplate @@ -22,18 +22,13 @@ class ModelParser(Subcommand): self.parser = subparsers.add_parser( "model", prog="llama model", -        description="Describe llama model interfaces", -        epilog=textwrap.dedent( -            """ -            Example: -                llama model -            """ -        ), +        description="Work with llama models", ) subparsers = self.parser.add_subparsers(title="model_subcommands") -    # Add sub-commandsa -    ModelTemplate.create(subparsers) +    # Add sub-commands +    ModelDownload.create(subparsers) ModelList.create(subparsers) +    ModelTemplate.create(subparsers) ModelDescribe.create(subparsers)