From ea6d9ec93701006b955da2d8dbe56ab4d1cabead Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Mon, 26 Aug 2024 17:42:09 -0700
Subject: [PATCH] templates take optional --format={json,function_tag}

---
 llama_toolchain/cli/model/template.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py
index abdf20811..ca898e7a6 100644
--- a/llama_toolchain/cli/model/template.py
+++ b/llama_toolchain/cli/model/template.py
@@ -32,6 +32,16 @@ class ModelTemplate(Subcommand):
         self._add_arguments()
         self.parser.set_defaults(func=self._run_model_template_cmd)
 
+    def _prompt_type(self, value):
+        from llama_models.llama3.api.datatypes import ToolPromptFormat
+
+        try:
+            return ToolPromptFormat(value.lower())
+        except ValueError:
+            raise argparse.ArgumentTypeError(
+                f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
+            ) from None
+
     def _add_arguments(self):
         self.parser.add_argument(
             "-m",
@@ -46,6 +56,13 @@ class ModelTemplate(Subcommand):
             help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
             required=False,
         )
+        self.parser.add_argument(
+            "--format",
+            type=str,
+            help="ToolPromptFormat (json or function_tag). This flag is used to print the template in a specific format.",
+            required=False,
+            default="json",
+        )
 
     def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
         from llama_models.llama3.api.interface import (
@@ -56,7 +73,8 @@ class ModelTemplate(Subcommand):
         from llama_toolchain.cli.table import print_table
 
         if args.name:
-            template, tokens_info = render_jinja_template(args.name)
+            tool_prompt_format = self._prompt_type(args.format)
+            template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
             rendered = ""
             for tok, is_special in tokens_info:
                 if is_special: