mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
templates take optional --format={json,function_tag}
This commit is contained in:
parent
69d9655ecd
commit
ea6d9ec937
1 changed files with 19 additions and 1 deletions
|
@ -32,6 +32,16 @@ class ModelTemplate(Subcommand):
|
|||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_template_cmd)
|
||||
|
||||
def _prompt_type(self, value):
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
|
||||
try:
|
||||
return ToolPromptFormat(value.lower())
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
|
||||
) from None
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
|
@ -46,6 +56,13 @@ class ModelTemplate(Subcommand):
|
|||
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
|
||||
required=False,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
help="ToolPromptFormat ( json or functino_tag). This flag is used to print the template in a specific formats.",
|
||||
required=False,
|
||||
default="json",
|
||||
)
|
||||
|
||||
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
||||
from llama_models.llama3.api.interface import (
|
||||
|
@ -56,7 +73,8 @@ class ModelTemplate(Subcommand):
|
|||
from llama_toolchain.cli.table import print_table
|
||||
|
||||
if args.name:
|
||||
template, tokens_info = render_jinja_template(args.name)
|
||||
tool_prompt_format = self._prompt_type(args.format)
|
||||
template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
|
||||
rendered = ""
|
||||
for tok, is_special in tokens_info:
|
||||
if is_special:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue