Add llama model download alias for llama download

This commit is contained in:
Ashwin Bharambe 2024-08-07 16:10:26 -07:00
parent fddaf5c929
commit 57402c1a19
3 changed files with 65 additions and 45 deletions

View file

@ -30,46 +30,47 @@ class Download(Subcommand):
description="Download a model from llama.meta.comf or HuggingFace hub", description="Download a model from llama.meta.comf or HuggingFace hub",
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
self._add_arguments() setup_download_parser(self.parser)
self.parser.set_defaults(func=partial(run_download_cmd, parser=self.parser))
def _add_arguments(self):
from llama_models.sku_list import all_registered_models
models = all_registered_models() def setup_download_parser(parser: argparse.ArgumentParser) -> None:
self.parser.add_argument( from llama_models.sku_list import all_registered_models
"--source",
choices=["meta", "huggingface"], models = all_registered_models()
required=True, parser.add_argument(
) "--source",
self.parser.add_argument( choices=["meta", "huggingface"],
"--model-id", required=True,
choices=[x.descriptor() for x in models], )
required=True, parser.add_argument(
) "--model-id",
self.parser.add_argument( choices=[x.descriptor() for x in models],
"--hf-token", required=True,
type=str, )
required=False, parser.add_argument(
default=None, "--hf-token",
help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.", type=str,
) required=False,
self.parser.add_argument( default=None,
"--meta-url", help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
type=str, )
required=False, parser.add_argument(
help="For source=meta, URL obtained from llama.meta.com after accepting license terms", "--meta-url",
) type=str,
self.parser.add_argument( required=False,
"--ignore-patterns", help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
type=str, )
required=False, parser.add_argument(
default="*.safetensors", "--ignore-patterns",
help=""" type=str,
required=False,
default="*.safetensors",
help="""
For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
safetensors files to avoid downloading duplicate weights. safetensors files to avoid downloading duplicate weights.
""", """,
) )
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
def _hf_download( def _hf_download(

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
class ModelDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama model download",
description="Download a model from llama.meta.comf or HuggingFace hub",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_toolchain.cli.download import setup_download_parser
setup_download_parser(self.parser)

View file

@ -5,9 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
import textwrap
from llama_toolchain.cli.model.describe import ModelDescribe from llama_toolchain.cli.model.describe import ModelDescribe
from llama_toolchain.cli.model.download import ModelDownload
from llama_toolchain.cli.model.list import ModelList from llama_toolchain.cli.model.list import ModelList
from llama_toolchain.cli.model.template import ModelTemplate from llama_toolchain.cli.model.template import ModelTemplate
@ -22,18 +22,13 @@ class ModelParser(Subcommand):
self.parser = subparsers.add_parser( self.parser = subparsers.add_parser(
"model", "model",
prog="llama model", prog="llama model",
description="Describe llama model interfaces", description="Work with llama models",
epilog=textwrap.dedent(
"""
Example:
llama model <subcommand> <options>
"""
),
) )
subparsers = self.parser.add_subparsers(title="model_subcommands") subparsers = self.parser.add_subparsers(title="model_subcommands")
# Add sub-commandsa # Add sub-commands
ModelTemplate.create(subparsers) ModelDownload.create(subparsers)
ModelList.create(subparsers) ModelList.create(subparsers)
ModelTemplate.create(subparsers)
ModelDescribe.create(subparsers) ModelDescribe.create(subparsers)