templates take optional --format={json,function_tag}

This commit is contained in:
Hardik Shah 2024-08-26 17:42:09 -07:00
parent 69d9655ecd
commit ea6d9ec937

View file

@ -32,6 +32,16 @@ class ModelTemplate(Subcommand):
self._add_arguments() self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd) self.parser.set_defaults(func=self._run_model_template_cmd)
def _prompt_type(self, value):
from llama_models.llama3.api.datatypes import ToolPromptFormat
try:
return ToolPromptFormat(value.lower())
except ValueError:
raise argparse.ArgumentTypeError(
f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
) from None
def _add_arguments(self): def _add_arguments(self):
self.parser.add_argument( self.parser.add_argument(
"-m", "-m",
@ -46,6 +56,13 @@ class ModelTemplate(Subcommand):
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...", help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
required=False, required=False,
) )
self.parser.add_argument(
"--format",
type=str,
help="ToolPromptFormat ( json or functino_tag). This flag is used to print the template in a specific formats.",
required=False,
default="json",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None: def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
from llama_models.llama3.api.interface import ( from llama_models.llama3.api.interface import (
@ -56,7 +73,8 @@ class ModelTemplate(Subcommand):
from llama_toolchain.cli.table import print_table from llama_toolchain.cli.table import print_table
if args.name: if args.name:
template, tokens_info = render_jinja_template(args.name) tool_prompt_format = self._prompt_type(args.format)
template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
rendered = "" rendered = ""
for tok, is_special in tokens_info: for tok, is_special in tokens_info:
if is_special: if is_special: