diff --git a/llama_toolchain/cli/model/describe.py b/llama_toolchain/cli/model/describe.py
new file mode 100644
index 000000000..bb08dfc65
--- /dev/null
+++ b/llama_toolchain/cli/model/describe.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import json
+
+from llama_models.llama3_1.api.sku_list import llama3_1_model_list
+
+from termcolor import colored
+
+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.cli.table import print_table
+
+
+class ModelDescribe(Subcommand):
+    """Show details about a model"""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "describe",
+            prog="llama model describe",
+            description="Show details about a llama model",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_model_describe_cmd)
+
+    def _add_arguments(self):
+        self.parser.add_argument(
+            "-m",
+            "--model-id",
+            type=str,
+        )
+
+    def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
+        models = llama3_1_model_list()
+        by_id = {model.sku.value: model for model in models}
+
+        if args.model_id not in by_id:
+            print(
+                f"Model {args.model_id} not found; try 'llama model list' for a list of available models."
+            )
+            return
+
+        model = by_id[args.model_id]
+
+        rows = [
+            (
+                colored("Model", "white", attrs=["bold"]),
+                colored(model.sku.value, "white", attrs=["bold"]),
+            ),
+            ("HuggingFace ID", model.huggingface_id or ""),
+            ("Description", model.description_markdown),
+            ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
+            ("Weights format", model.quantization_format.value),
+            ("Model params.json", json.dumps(model.model_args, indent=4)),
+        ]
+
+        print_table(
+            rows,
+            separate_rows=True,
+        )
diff --git a/llama_toolchain/cli/model/model.py b/llama_toolchain/cli/model/model.py
index 09222437e..34cec3a67 100644
--- a/llama_toolchain/cli/model/model.py
+++ b/llama_toolchain/cli/model/model.py
@@ -7,9 +7,10 @@
 import argparse
 import textwrap
 
+from llama_toolchain.cli.model.describe import ModelDescribe
 from llama_toolchain.cli.model.list import ModelList
-
 from llama_toolchain.cli.model.template import ModelTemplate
+
 from llama_toolchain.cli.subcommand import Subcommand
 
 
@@ -35,3 +36,4 @@ class ModelParser(Subcommand):
         # Add sub-commandsa
         ModelTemplate.create(subparsers)
         ModelList.create(subparsers)
+        ModelDescribe.create(subparsers)
diff --git a/llama_toolchain/cli/table.py b/llama_toolchain/cli/table.py
index 07457eec8..b63ae3467 100644
--- a/llama_toolchain/cli/table.py
+++ b/llama_toolchain/cli/table.py
@@ -22,7 +22,6 @@ def format_row(row, col_widths):
             if line.strip() == "":
                 lines.append("")
             else:
-                line = line.strip()
                 lines.extend(
                     textwrap.wrap(
                         line, width, break_long_words=False, replace_whitespace=False
@@ -45,14 +44,18 @@ def print_table(rows, headers=None, separate_rows: bool = False):
+    def itemlen(item):
+        return len(strip_ansi_colors(item))
+
     rows = [[x or "" for x in row] for row in rows]
     if not headers:
-        col_widths = [
-            max(len(strip_ansi_colors(item)) for item in col) for col in zip(*rows)
-        ]
+        col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
     else:
         col_widths = [
-            max(len(header), max(len(strip_ansi_colors(item)) for item in col))
+            max(
+                itemlen(header),
+                max(itemlen(item) for item in col),
+            )
             for header, col in zip(headers, zip(*rows))
         ]
     col_widths = [min(w, 80) for w in col_widths]