Add model describe subcommand

This commit is contained in:
Ashwin Bharambe 2024-07-29 18:19:53 -07:00
parent 9d7f283722
commit 45b8a7ffcd
3 changed files with 77 additions and 6 deletions

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
from llama_models.llama3_1.api.sku_list import llama3_1_model_list
from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table
class ModelDescribe(Subcommand):
"""Show details about a model"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"describe",
prog="llama model describe",
description="Show details about a llama model",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_describe_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-id",
type=str,
)
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
models = llama3_1_model_list()
by_id = {model.sku.value: model for model in models}
if args.model_id not in by_id:
print(
f"Model {args.model_id} not found; try 'llama model list' for a list of available models."
)
return
model = by_id[args.model_id]
rows = [
(
colored("Model", "white", attrs=["bold"]),
colored(model.sku.value, "white", attrs=["bold"]),
),
("HuggingFace ID", model.huggingface_id or "<Not Available>"),
("Description", model.description_markdown),
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
("Weights format", model.quantization_format.value),
("Model params.json", json.dumps(model.model_args, indent=4)),
]
print_table(
rows,
separate_rows=True,
)

View file

@ -7,9 +7,10 @@
import argparse
import textwrap
from llama_toolchain.cli.model.describe import ModelDescribe
from llama_toolchain.cli.model.list import ModelList
from llama_toolchain.cli.model.template import ModelTemplate
from llama_toolchain.cli.subcommand import Subcommand
@ -35,3 +36,4 @@ class ModelParser(Subcommand):
# Add sub-commandsa
ModelTemplate.create(subparsers)
ModelList.create(subparsers)
ModelDescribe.create(subparsers)

View file

@ -22,7 +22,6 @@ def format_row(row, col_widths):
if line.strip() == "":
lines.append("")
else:
line = line.strip()
lines.extend(
textwrap.wrap(
line, width, break_long_words=False, replace_whitespace=False
@ -45,14 +44,18 @@ def format_row(row, col_widths):
def print_table(rows, headers=None, separate_rows: bool = False):
def itemlen(item):
return len(strip_ansi_colors(item))
rows = [[x or "" for x in row] for row in rows]
if not headers:
col_widths = [
max(len(strip_ansi_colors(item)) for item in col) for col in zip(*rows)
]
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
else:
col_widths = [
max(len(header), max(len(strip_ansi_colors(item)) for item in col))
max(
itemlen(header),
max(itemlen(item) for item in col),
)
for header, col in zip(headers, zip(*rows))
]
col_widths = [min(w, 80) for w in col_widths]