From 3583cf2d517b8bd91c88fc4f85e7696f2729608a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 26 Jul 2024 15:39:46 -0700 Subject: [PATCH] update model template output to be prettier, more consumable --- llama_toolchain/cli/model/template.py | 89 ++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py index 498549d2a..9e3bee8ba 100644 --- a/llama_toolchain/cli/model/template.py +++ b/llama_toolchain/cli/model/template.py @@ -5,16 +5,23 @@ # the root directory of this source tree. import argparse +import re import textwrap from llama_models.llama3_1.api.interface import ( list_jinja_templates, render_jinja_template, ) +from termcolor import colored, cprint from llama_toolchain.cli.subcommand import Subcommand +def strip_ansi_colors(text): + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) + + class ModelTemplate(Subcommand): """Llama model cli for describe a model template (message formats)""" @@ -52,6 +59,84 @@ class ModelTemplate(Subcommand): def _run_model_template_cmd(self, args: argparse.Namespace) -> None: if args.template: - render_jinja_template(args.template) + template, tokens_info = render_jinja_template(args.template) + rendered = "" + for tok, is_special in tokens_info: + if is_special: + rendered += colored(tok, "yellow", attrs=["bold"]) + else: + rendered += tok + rendered += "\n" + print_table( + [ + ("Name", colored(template.template_name, "white", attrs=["bold"])), + ("Template", rendered), + ("Notes", template.notes), + ], + separate_rows=True, + ) else: - list_jinja_templates() + templates = list_jinja_templates() + headers = ["Role", "Template Name"] + print_table( + [(t.role, t.template_name) for t in templates], + headers, + ) + + +def format_row(row, col_widths): + def wrap(text, width): + lines = [] + for line in text.split("\n"): + if line.strip() == "": + lines.append("") + else: + line = line.strip() + lines.extend( + textwrap.wrap( + line, width, break_long_words=False, replace_whitespace=False + ) + ) + return lines + + wrapped = [wrap(item, width) for item, width in zip(row, col_widths)] + max_lines = max(len(subrow) for subrow in wrapped) + + lines = [] + for i in range(max_lines): + line = [] + for cell_lines, width in zip(wrapped, col_widths): + value = cell_lines[i] if i < len(cell_lines) else "" + line.append(value + " " * (width - len(strip_ansi_colors(value)))) + lines.append("| " + (" | ".join(line)) + " |") + + return "\n".join(lines) + + +def print_table(rows, headers=None, separate_rows: bool = False): + if not headers: + col_widths = [ + max(len(strip_ansi_colors(item)) for item in col) for col in zip(*rows) + ] + else: + col_widths = [ + max(len(header), max(len(strip_ansi_colors(item)) for item in col)) + for header, col in zip(headers, zip(*rows)) + ] + col_widths = [min(w, 80) for w in col_widths] + + header_line = "+".join("-" * (width + 2) for width in col_widths) + header_line = f"+{header_line}+" + + if headers: + print(header_line) + cprint(format_row(headers, col_widths), "white", attrs=["bold"]) + + print(header_line) + for row in rows: + print(format_row(row, col_widths)) + if separate_rows: + print(header_line) + + if not separate_rows: + print(header_line)