# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_models.llama3_1.api.sku_list import llama3_1_model_list

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table


class ModelList(Subcommand):
    """List available llama models.

    Implements the ``llama model list`` sub-command, which prints one table
    row per model of the selected model family.
    """

    # Maps every supported ``--model-family`` value to the callable that
    # returns the models of that family.
    _MODEL_LISTS = {
        "llama3_1": llama3_1_model_list,
    }

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Register the ``list`` parser on *subparsers* and bind its handler."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "list",
            prog="llama model list",
            description="Show available llama models",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_model_list_cmd)

    def _add_arguments(self):
        """Declare the command-line options accepted by ``llama model list``."""
        self.parser.add_argument(
            "-m",
            "--model-family",
            type=str,
            default="llama3_1",
            help="Model Family (llama3_1, llama3_X, etc.)",
        )

    def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
        """Print a table describing every model of the requested family.

        Args:
            args: Parsed arguments. ``args.model_family`` selects the family;
                when the attribute is absent the parser default is assumed.

        Exits through ``self.parser.error`` when the requested family is not
        one of the supported ones, instead of silently listing another family.
        """
        family = getattr(args, "model_family", "llama3_1")
        list_models = self._MODEL_LISTS.get(family)
        if list_models is None:
            supported = ", ".join(sorted(self._MODEL_LISTS))
            # ArgumentParser.error() prints the usage text plus this message
            # and terminates the process, so nothing below runs.
            self.parser.error(
                f"unsupported model family '{family}' (supported: {supported})"
            )

        headers = [
            "Model ID",
            "HuggingFace ID",
            "Context Length",
            "Hardware Requirements",
        ]

        rows = [
            [
                model.sku.value,
                model.huggingface_id,
                # Shown in units of 1024 tokens, e.g. "128K".
                f"{model.max_seq_length // 1024}K",
                (
                    f"{model.hardware_requirements.gpu_count} "
                    # Plural for every count except exactly one.
                    f"GPU{'s' if model.hardware_requirements.gpu_count != 1 else ''}, "
                    f"each >= {model.hardware_requirements.memory_gb_per_gpu}GB VRAM"
                ),
            ]
            for model in list_models()
        ]
        print_table(
            rows,
            headers,
            separate_rows=True,
        )
llama_toolchain.cli.model.list import ModelList + from llama_toolchain.cli.model.template import ModelTemplate from llama_toolchain.cli.subcommand import Subcommand @@ -31,5 +33,5 @@ class ModelParser(Subcommand): subparsers = self.parser.add_subparsers(title="model_subcommands") # Add sub-commandsa - # ModelDescribe.create(subparsers) ModelTemplate.create(subparsers) + ModelList.create(subparsers) diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py index 9e3bee8ba..911456a46 100644 --- a/llama_toolchain/cli/model/template.py +++ b/llama_toolchain/cli/model/template.py @@ -5,21 +5,16 @@ # the root directory of this source tree. import argparse -import re import textwrap from llama_models.llama3_1.api.interface import ( list_jinja_templates, render_jinja_template, ) -from termcolor import colored, cprint +from termcolor import colored from llama_toolchain.cli.subcommand import Subcommand - - -def strip_ansi_colors(text): - ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", text) +from llama_toolchain.cli.table import print_table class ModelTemplate(Subcommand): @@ -82,61 +77,3 @@ class ModelTemplate(Subcommand): [(t.role, t.template_name) for t in templates], headers, ) - - -def format_row(row, col_widths): - def wrap(text, width): - lines = [] - for line in text.split("\n"): - if line.strip() == "": - lines.append("") - else: - line = line.strip() - lines.extend( - textwrap.wrap( - line, width, break_long_words=False, replace_whitespace=False - ) - ) - return lines - - wrapped = [wrap(item, width) for item, width in zip(row, col_widths)] - max_lines = max(len(subrow) for subrow in wrapped) - - lines = [] - for i in range(max_lines): - line = [] - for cell_lines, width in zip(wrapped, col_widths): - value = cell_lines[i] if i < len(cell_lines) else "" - line.append(value + " " * (width - len(strip_ansi_colors(value)))) - lines.append("| " + (" | ".join(line)) + " |") - - return 
"\n".join(lines) - - -def print_table(rows, headers=None, separate_rows: bool = False): - if not headers: - col_widths = [ - max(len(strip_ansi_colors(item)) for item in col) for col in zip(*rows) - ] - else: - col_widths = [ - max(len(header), max(len(strip_ansi_colors(item)) for item in col)) - for header, col in zip(headers, zip(*rows)) - ] - col_widths = [min(w, 80) for w in col_widths] - - header_line = "+".join("-" * (width + 2) for width in col_widths) - header_line = f"+{header_line}+" - - if headers: - print(header_line) - cprint(format_row(headers, col_widths), "white", attrs=["bold"]) - - print(header_line) - for row in rows: - print(format_row(row, col_widths)) - if separate_rows: - print(header_line) - - if not separate_rows: - print(header_line) diff --git a/llama_toolchain/cli/table.py b/llama_toolchain/cli/table.py new file mode 100644 index 000000000..07457eec8 --- /dev/null +++ b/llama_toolchain/cli/table.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
import re
import textwrap

# Matches ANSI escape sequences such as terminal colour codes. Compiled once
# at import time instead of on every strip_ansi_colors() call.
_ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")

# Columns are never drawn wider than this; longer cell text is wrapped.
_MAX_COL_WIDTH = 80


def strip_ansi_colors(text):
    """Return *text* with every ANSI escape sequence removed."""
    return _ANSI_ESCAPE.sub("", text)


def _visible_width(cell):
    """Return the printed width of *cell*: its longest line, ignoring ANSI codes."""
    return max(len(strip_ansi_colors(line)) for line in cell.split("\n"))


def format_row(row, col_widths):
    """Render one table row as ``| a | b |`` text, wrapping long cells.

    Args:
        row: Cell strings, one per column. Cells may contain newlines and
            ANSI escape sequences.
        col_widths: Target width of each column, paired with *row* by position.

    Returns:
        The row as a string. It holds several newline-separated lines when at
        least one cell needs more than one line; shorter cells are padded with
        blanks. An empty *row* yields an empty string.
    """

    def wrap(text, width):
        lines = []
        for line in text.split("\n"):
            line = line.strip()
            if not line:
                # Keep blank lines so the cell's vertical layout is preserved.
                lines.append("")
                continue
            visible = strip_ansi_colors(line)
            if visible != line and len(visible) <= width:
                # textwrap counts escape bytes as printable characters and
                # would split coloured text that actually fits the column.
                lines.append(line)
                continue
            lines.extend(
                textwrap.wrap(
                    line, width, break_long_words=False, replace_whitespace=False
                )
            )
        return lines

    wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]
    # default=0 keeps an empty row from raising on max() of nothing.
    max_lines = max((len(cell_lines) for cell_lines in wrapped), default=0)

    lines = []
    for i in range(max_lines):
        line = []
        for cell_lines, width in zip(wrapped, col_widths):
            value = cell_lines[i] if i < len(cell_lines) else ""
            # Pad by the visible length so colour codes do not skew alignment.
            line.append(value + " " * (width - len(strip_ansi_colors(value))))
        lines.append("| " + (" | ".join(line)) + " |")

    return "\n".join(lines)


def print_table(rows, headers=None, separate_rows: bool = False):
    """Print *rows* to stdout as an ASCII table.

    Args:
        rows: Iterable of rows, each an iterable of cells. ``None`` cells are
            shown blank; any other non-string cell is converted with ``str``.
        headers: Optional column titles, printed in bold above the rows.
        separate_rows: When true, draw a border line after every row instead
            of a single closing border after the last one.

    Column widths follow the widest visible line of each column (headers
    included) and are capped at 80 characters. Nothing is printed when there
    are no columns to draw. Headers without rows print just the header box.
    """
    rows = [["" if cell is None else str(cell) for cell in row] for row in rows]

    # Measuring headers as one more row sizes the table correctly even when
    # there are no data rows at all.
    measured = ([list(headers)] if headers else []) + rows
    col_widths = [
        min(max(_visible_width(cell) for cell in col), _MAX_COL_WIDTH)
        for col in zip(*measured)
    ]
    if not col_widths:
        return

    header_line = "+".join("-" * (width + 2) for width in col_widths)
    header_line = f"+{header_line}+"

    print(header_line)
    if headers:
        # Imported here because only the bold header row needs termcolor.
        from termcolor import cprint

        cprint(format_row(headers, col_widths), "white", attrs=["bold"])
        print(header_line)

    for row in rows:
        print(format_row(row, col_widths))
        if separate_rows:
            print(header_line)

    # With no rows the border under the header already closes the table.
    if rows and not separate_rows:
        print(header_line)