Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly

Author: Hardik Shah
Date:   2024-08-26 17:03:34 -07:00
parent decbbc127b
commit 69d9655ecd
4 changed files with 21 additions and 5 deletions


```diff
@@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
```
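For reference, `ToolPromptFormat`, the newly imported name, is an enum in `llama_models.llama3.api.datatypes` that selects how tool definitions get rendered into the prompt. A minimal sketch, assuming only the `json` member that the default below relies on plus a `function_tag` variant; the full enum may carry more members:

```python
from enum import Enum

class ToolPromptFormat(Enum):
    """Sketch: how tool definitions are rendered into the prompt."""
    json = "json"                  # tools described as JSON schemas
    function_tag = "function_tag"  # tools wrapped in <function=...> tags
```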
```diff
@@ -279,6 +279,7 @@ class Llama:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         if (
             max_gen_len is None
```
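With the new keyword argument in place, callers can pick a tool encoding per request while existing call sites keep working, since the default stays `ToolPromptFormat.json`. A hedged usage sketch: the enclosing method is not named in these hunks, so `chat_completion` is an assumed name for it, and `llama` stands for an already-built `Llama` instance:

```python
from llama_models.llama3.api.datatypes import ToolPromptFormat

# `llama` is an already-built Llama instance; `chat_completion` is an assumed
# name for the generator method whose signature is modified above.
for result in llama.chat_completion(
    messages,                                  # the dialog so far: list[Message]
    temperature=0.6,
    top_p=0.9,
    tool_prompt_format=ToolPromptFormat.json,  # new kwarg; json is the default
):
    ...  # consume generated tokens/results
```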
```diff
@@ -288,7 +289,10 @@ class Llama:
             max_gen_len = self.model.params.max_seq_len - 1

         yield from self.generate(
-            model_input=self.formatter.encode_dialog_prompt(messages),
+            model_input=self.formatter.encode_dialog_prompt(
+                messages,
+                tool_prompt_format,
+            ),
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
```
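These hunks only thread the argument through the generation entry point; per the commit title, the substantive change lives in `ChatFormat.encode_message` (one of the other changed files, not shown in this view). A sketch of the signatures that change implies, hedged since the actual `chat_format` diff is not visible here:

```python
from typing import List

from llama_models.llama3.api.datatypes import Message, ToolPromptFormat

class ChatFormat:  # sketch of the real class in llama_models.llama3.api.chat_format
    def encode_message(
        self,
        message: Message,
        tool_prompt_format: ToolPromptFormat,  # newly threaded through
    ):
        # Renders one message to tokens; tool definitions carried by the
        # message are presumably formatted according to tool_prompt_format.
        ...

    def encode_dialog_prompt(
        self,
        messages: List[Message],
        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
    ) -> "ModelInput":  # ModelInput: the tokens bundle imported in the diff above
        # Presumably calls encode_message(message, tool_prompt_format) for each
        # message, so tools end up encoded in the requested format.
        ...
```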