update toolchain to work with updated imports from llama_models

This commit is contained in:
Ashwin Bharambe 2024-07-30 17:52:57 -07:00
parent 23014ea4d1
commit 1bc81eae7b
3 changed files with 6 additions and 14 deletions

View file

@ -16,11 +16,8 @@ import httpx
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_models.llama3_1.api.datatypes import ( from llama_models.datatypes import CheckpointQuantizationFormat, ModelDefinition
CheckpointQuantizationFormat, from llama_models.sku_list import (
ModelDefinition,
)
from llama_models.llama3_1.api.sku_list import (
llama3_1_model_list, llama3_1_model_list,
llama_meta_folder_path, llama_meta_folder_path,
llama_meta_pth_size, llama_meta_pth_size,

View file

@ -9,7 +9,7 @@ import json
from enum import Enum from enum import Enum
from llama_models.llama3_1.api.sku_list import llama3_1_model_list from llama_models.sku_list import llama3_1_model_list
from termcolor import colored from termcolor import colored
@ -43,6 +43,7 @@ class ModelDescribe(Subcommand):
"-m", "-m",
"--model-id", "--model-id",
type=str, type=str,
required=True,
) )
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None: def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:

View file

@ -6,7 +6,7 @@
import argparse import argparse
from llama_models.llama3_1.api.sku_list import llama3_1_model_list from llama_models.sku_list import llama3_1_model_list
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table from llama_toolchain.cli.table import print_table
@ -27,13 +27,7 @@ class ModelList(Subcommand):
self.parser.set_defaults(func=self._run_model_list_cmd) self.parser.set_defaults(func=self._run_model_list_cmd)
def _add_arguments(self): def _add_arguments(self):
self.parser.add_argument( pass
"-m",
"--model-family",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None: def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
models = llama3_1_model_list() models = llama3_1_model_list()