From 1bc81eae7b34f7637c0ea35566def214ea2d08d7 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 30 Jul 2024 17:52:57 -0700 Subject: [PATCH] update toolchain to work with updated imports from llama_models --- llama_toolchain/cli/download.py | 7 ++----- llama_toolchain/cli/model/describe.py | 3 ++- llama_toolchain/cli/model/list.py | 10 ++-------- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index 63452a311..233573ed4 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -16,11 +16,8 @@ import httpx from huggingface_hub import snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError -from llama_models.llama3_1.api.datatypes import ( - CheckpointQuantizationFormat, - ModelDefinition, -) -from llama_models.llama3_1.api.sku_list import ( +from llama_models.datatypes import CheckpointQuantizationFormat, ModelDefinition +from llama_models.sku_list import ( llama3_1_model_list, llama_meta_folder_path, llama_meta_pth_size, diff --git a/llama_toolchain/cli/model/describe.py b/llama_toolchain/cli/model/describe.py index 687badd85..6551e6e65 100644 --- a/llama_toolchain/cli/model/describe.py +++ b/llama_toolchain/cli/model/describe.py @@ -9,7 +9,7 @@ import json from enum import Enum -from llama_models.llama3_1.api.sku_list import llama3_1_model_list +from llama_models.sku_list import llama3_1_model_list from termcolor import colored @@ -43,6 +43,7 @@ class ModelDescribe(Subcommand): "-m", "--model-id", type=str, + required=True, ) def _run_model_describe_cmd(self, args: argparse.Namespace) -> None: diff --git a/llama_toolchain/cli/model/list.py b/llama_toolchain/cli/model/list.py index 0f1fff46d..9d26bb181 100644 --- a/llama_toolchain/cli/model/list.py +++ b/llama_toolchain/cli/model/list.py @@ -6,7 +6,7 @@ import argparse -from llama_models.llama3_1.api.sku_list import llama3_1_model_list +from llama_models.sku_list import llama3_1_model_list from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.table import print_table @@ -27,13 +27,7 @@ class ModelList(Subcommand): self.parser.set_defaults(func=self._run_model_list_cmd) def _add_arguments(self): - self.parser.add_argument( - "-m", - "--model-family", - type=str, - default="llama3_1", - help="Model Family (llama3_1, llama3_X, etc.)", - ) + pass def _run_model_list_cmd(self, args: argparse.Namespace) -> None: models = llama3_1_model_list()