Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly

commit 69d9655ecd
parent decbbc127b
Author: Hardik Shah
Date:   2024-08-26 17:03:34 -07:00

4 changed files with 21 additions and 5 deletions
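
For context: ToolPromptFormat selects how tool definitions are rendered into the prompt when ChatFormat.encode_message encodes a message. Below is a minimal sketch of the enum this change relies on, assuming it lives in llama_models.llama3.api.datatypes as imported in the first hunk; only the json member is confirmed by this diff, and function_tag is an assumed additional member:

from enum import Enum

class ToolPromptFormat(Enum):
    # json: tool schemas rendered as JSON in the prompt (confirmed below
    # as the default). function_tag is an assumed alternative member that
    # would wrap tool calls in tag-style markup; it is not shown in this diff.
    json = "json"
    function_tag = "function_tag"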


@@ -11,7 +11,7 @@ from functools import partial
 from typing import Generator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
@@ -27,6 +27,7 @@ class InferenceArgs:
     top_p: float
     max_gen_len: int
     logprobs: bool
+    tool_prompt_format: ToolPromptFormat
 
 
 class ModelRunner:
@@ -41,6 +42,7 @@ class ModelRunner:
             task.top_p,
             task.max_gen_len,
             task.logprobs,
+            task.tool_prompt_format,
         )
 
 
@@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         req_obj = InferenceArgs(
             messages=deepcopy(messages),
@@ -100,6 +103,7 @@
             top_p=top_p,
             max_gen_len=max_gen_len,
             logprobs=logprobs,
+            tool_prompt_format=tool_prompt_format,
         )
         gen = self.group.run_inference(req_obj)
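
Taken together, the new parameter flows from the generator's public entry point through InferenceArgs into the worker's inference call, so the worker side can pass it on to ChatFormat.encode_message and encode tools in the requested format. A hedged usage sketch follows; the method name chat_completion on the generator and the function_tag member are assumptions not visible in this diff, and messages is an illustrative variable:

from llama_models.llama3.api.datatypes import ToolPromptFormat

# generator: a LlamaModelParallelGenerator. Callers that want tag-style
# tool encoding can now override the ToolPromptFormat.json default; the
# value is forwarded via InferenceArgs to the model-parallel workers.
for token_result in generator.chat_completion(
    messages=messages,
    tool_prompt_format=ToolPromptFormat.function_tag,  # assumed member
):
    print(token_result)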