Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly

Hardik Shah 2024-08-26 17:03:34 -07:00
parent decbbc127b
commit 69d9655ecd
4 changed files with 21 additions and 5 deletions
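
For context, ToolPromptFormat is the enum in llama_models.llama3.api.datatypes that selects how tool definitions and tool calls are rendered when messages are encoded into a prompt; this commit threads it from the caller down to the formatter. A minimal sketch of the enum usage (the json member appears in the diff below; function_tag is an assumption about the llama_models version in use):

# Sketch only, not part of this commit.
from llama_models.llama3.api.datatypes import ToolPromptFormat

fmt = ToolPromptFormat.json            # default used throughout this diff
# fmt = ToolPromptFormat.function_tag  # assumed alternative member; check your llama_models version
print(fmt)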


@@ -44,7 +44,12 @@ EventType = AgenticSystemTurnResponseEventType
 class EventLogger:
-    async def log(self, event_generator, stream=True):
+    async def log(
+        self,
+        event_generator,
+        stream=True,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+    ):
         previous_event_type = None
         previous_step_type = None
@@ -132,7 +137,9 @@ class EventLogger:
             if event_type == EventType.step_complete.value:
                 response = event.payload.step_details.model_response
                 if response.tool_calls:
-                    content = ToolUtils.encode_tool_call(response.tool_calls[0])
+                    content = ToolUtils.encode_tool_call(
+                        response.tool_calls[0], tool_prompt_format
+                    )
                 else:
                     content = response.content
                 yield event, LogEvent(
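
The hunks above let the agentic-system EventLogger re-encode tool calls from a completed inference step using the caller's chosen format rather than relying on the encoder's default. A minimal usage sketch, assuming an async event generator for a turn is already available (its construction, and the exact LogEvent API, are not shown in this diff):

# Sketch only: exercising the new keyword argument on EventLogger.log.
# EventLogger is the class modified above; assumed importable from its module.
from llama_models.llama3.api.datatypes import ToolPromptFormat

async def log_turn(event_generator):
    # event_generator is assumed to come from an agentic-system turn; not part of this diff.
    async for event, log_event in EventLogger().log(
        event_generator,
        tool_prompt_format=ToolPromptFormat.json,  # new in this commit; json remains the default
    ):
        print(log_event)  # placeholder: LogEvent's printing API is not visible here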


@@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
@@ -279,6 +279,7 @@ class Llama:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         if (
             max_gen_len is None
@@ -288,7 +289,10 @@
             max_gen_len = self.model.params.max_seq_len - 1
         yield from self.generate(
-            model_input=self.formatter.encode_dialog_prompt(messages),
+            model_input=self.formatter.encode_dialog_prompt(
+                messages,
+                tool_prompt_format,
+            ),
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
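
This is the plumbing the commit title refers to: the reference generator now forwards the requested format to ChatFormat.encode_dialog_prompt (presumably delegating to the encode_message change named in the title), so tool definitions and tool calls are rendered in the expected style. A standalone sketch of that call, assuming a local tokenizer model file and the llama_models version this commit pairs with (the path and message content are placeholders):

# Sketch only: calling encode_dialog_prompt with an explicit ToolPromptFormat,
# mirroring the call introduced in the diff above.
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ToolPromptFormat, UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer

tokenizer = Tokenizer(model_path="tokenizer.model")  # placeholder path
formatter = ChatFormat(tokenizer)

model_input = formatter.encode_dialog_prompt(
    [UserMessage(content="What is the weather in Menlo Park?")],
    ToolPromptFormat.json,  # passed positionally, as in the diff
)
print(len(model_input.tokens))  # returns a ModelInput, imported alongside ChatFormat above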


@@ -104,6 +104,7 @@ class MetaReferenceInferenceImpl(Inference):
             top_p=request.sampling_params.top_p,
             max_gen_len=request.sampling_params.max_tokens,
             logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
         ):
             buffer += token_result.text
             tokens.append(token_result.token)
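
On the provider side the format simply rides along on the request object; the hunk above reads request.tool_prompt_format and forwards it to the generator. A rough stand-in to illustrate that plumbing (the real llama-stack request type and its other fields are not shown in this diff, so a placeholder dataclass is used here):

# Sketch only: a stand-in request carrying the tool prompt format.
from dataclasses import dataclass
from typing import List

from llama_models.llama3.api.datatypes import Message, ToolPromptFormat, UserMessage

@dataclass
class FakeChatRequest:  # placeholder for the real chat-completion request type
    messages: List[Message]
    logprobs: bool
    tool_prompt_format: ToolPromptFormat

request = FakeChatRequest(
    messages=[UserMessage(content="Book a table for two at 7pm")],
    logprobs=False,
    tool_prompt_format=ToolPromptFormat.json,  # the field the hunk above forwards
)
print(request.tool_prompt_format)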


@@ -11,7 +11,7 @@ from functools import partial
 from typing import Generator, List, Optional
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
@@ -27,6 +27,7 @@ class InferenceArgs:
     top_p: float
     max_gen_len: int
     logprobs: bool
+    tool_prompt_format: ToolPromptFormat
 class ModelRunner:
@@ -41,6 +42,7 @@
             task.top_p,
             task.max_gen_len,
             task.logprobs,
+            task.tool_prompt_format,
         )
@@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         req_obj = InferenceArgs(
             messages=deepcopy(messages),
@@ -100,6 +103,7 @@
             top_p=top_p,
             max_gen_len=max_gen_len,
             logprobs=logprobs,
+            tool_prompt_format=tool_prompt_format,
         )
         gen = self.group.run_inference(req_obj)
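
For the model-parallel path the format has to cross a process boundary, so it becomes a field of the InferenceArgs task object that ModelRunner later unpacks. A construction sketch mirroring the last hunk; only the fields visible in the hunks above are certain, and temperature is an assumption about the surrounding code:

# Sketch only: building the task payload that ModelRunner unpacks.
# InferenceArgs is the class extended in the hunks above; assumed importable
# from the same module.
from copy import deepcopy

from llama_models.llama3.api.datatypes import ToolPromptFormat, UserMessage

messages = [UserMessage(content="What is 2 + 2?")]

req_obj = InferenceArgs(
    messages=deepcopy(messages),
    temperature=0.6,  # assumed field and default; not visible in the hunks above
    top_p=0.9,
    max_gen_len=256,
    logprobs=False,
    tool_prompt_format=ToolPromptFormat.json,  # new field added by this commit
)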