mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
templates take optional --format={json,function_tag}
This commit is contained in:
parent
69d9655ecd
commit
ea6d9ec937
1 changed files with 19 additions and 1 deletions
|
@ -32,6 +32,16 @@ class ModelTemplate(Subcommand):
|
||||||
self._add_arguments()
|
self._add_arguments()
|
||||||
self.parser.set_defaults(func=self._run_model_template_cmd)
|
self.parser.set_defaults(func=self._run_model_template_cmd)
|
||||||
|
|
||||||
|
def _prompt_type(self, value):
    """Argparse type-converter: turn a CLI string into a ToolPromptFormat.

    The comparison is case-insensitive. Raises argparse.ArgumentTypeError
    (so argparse prints a clean usage error) listing every valid format
    when the value does not name a known ToolPromptFormat member.
    """
    # Imported lazily to keep CLI startup light.
    from llama_models.llama3.api.datatypes import ToolPromptFormat

    normalized = value.lower()
    try:
        return ToolPromptFormat(normalized)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
        ) from None
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"-m",
|
"-m",
|
||||||
|
@ -46,6 +56,13 @@ class ModelTemplate(Subcommand):
|
||||||
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
|
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
|
||||||
required=False,
|
required=False,
|
||||||
)
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--format",
|
||||||
|
type=str,
|
||||||
|
help="ToolPromptFormat ( json or functino_tag). This flag is used to print the template in a specific formats.",
|
||||||
|
required=False,
|
||||||
|
default="json",
|
||||||
|
)
|
||||||
|
|
||||||
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_models.llama3.api.interface import (
|
from llama_models.llama3.api.interface import (
|
||||||
|
@ -56,7 +73,8 @@ class ModelTemplate(Subcommand):
|
||||||
from llama_toolchain.cli.table import print_table
|
from llama_toolchain.cli.table import print_table
|
||||||
|
|
||||||
if args.name:
|
if args.name:
|
||||||
template, tokens_info = render_jinja_template(args.name)
|
tool_prompt_format = self._prompt_type(args.format)
|
||||||
|
template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
|
||||||
rendered = ""
|
rendered = ""
|
||||||
for tok, is_special in tokens_info:
|
for tok, is_special in tokens_info:
|
||||||
if is_special:
|
if is_special:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue