diff --git a/llama_toolchain/agentic_system/event_logger.py b/llama_toolchain/agentic_system/event_logger.py
index 6b8d034a6..3d15ee239 100644
--- a/llama_toolchain/agentic_system/event_logger.py
+++ b/llama_toolchain/agentic_system/event_logger.py
@@ -44,7 +44,12 @@ EventType = AgenticSystemTurnResponseEventType
 
 
 class EventLogger:
-    async def log(self, event_generator, stream=True):
+    async def log(
+        self,
+        event_generator,
+        stream=True,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+    ):
         previous_event_type = None
         previous_step_type = None
 
@@ -132,7 +137,9 @@ class EventLogger:
             if event_type == EventType.step_complete.value:
                 response = event.payload.step_details.model_response
                 if response.tool_calls:
-                    content = ToolUtils.encode_tool_call(response.tool_calls[0])
+                    content = ToolUtils.encode_tool_call(
+                        response.tool_calls[0], tool_prompt_format
+                    )
                 else:
                     content = response.content
                 yield event, LogEvent(
diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py
index 058874702..1329f8699 100644
--- a/llama_toolchain/inference/meta_reference/generation.py
+++ b/llama_toolchain/inference/meta_reference/generation.py
@@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
@@ -279,6 +279,7 @@ class Llama:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         if (
             max_gen_len is None
@@ -288,7 +289,10 @@ class Llama:
             max_gen_len = self.model.params.max_seq_len - 1
 
         yield from self.generate(
-            model_input=self.formatter.encode_dialog_prompt(messages),
+            model_input=self.formatter.encode_dialog_prompt(
+                messages,
+                tool_prompt_format,
+            ),
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py
index 87ffc5226..72cb105ff 100644
--- a/llama_toolchain/inference/meta_reference/inference.py
+++ b/llama_toolchain/inference/meta_reference/inference.py
@@ -104,6 +104,7 @@ class MetaReferenceInferenceImpl(Inference):
             top_p=request.sampling_params.top_p,
             max_gen_len=request.sampling_params.max_tokens,
             logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
         ):
             buffer += token_result.text
             tokens.append(token_result.token)
diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py
index 3de4a6381..b5d81287b 100644
--- a/llama_toolchain/inference/meta_reference/model_parallel.py
+++ b/llama_toolchain/inference/meta_reference/model_parallel.py
@@ -11,7 +11,7 @@ from functools import partial
 from typing import Generator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_models.sku_list import resolve_model
@@ -27,6 +27,7 @@ class InferenceArgs:
     top_p: float
     max_gen_len: int
     logprobs: bool
+    tool_prompt_format: ToolPromptFormat
 
 
 class ModelRunner:
@@ -41,6 +42,7 @@ class ModelRunner:
             task.top_p,
             task.max_gen_len,
             task.logprobs,
+            task.tool_prompt_format,
         )
 
 
@@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         req_obj = InferenceArgs(
             messages=deepcopy(messages),
@@ -100,6 +103,7 @@ class LlamaModelParallelGenerator:
            top_p=top_p,
            max_gen_len=max_gen_len,
            logprobs=logprobs,
+           tool_prompt_format=tool_prompt_format,
         )
 
         gen = self.group.run_inference(req_obj)
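
Usage sketch for reviewers (not part of the diff): the snippet below shows how a caller might pass the new tool_prompt_format keyword to EventLogger.log. Only EventLogger, LogEvent, ToolPromptFormat, and the log() signature come from the changes above; the events argument and the final print are illustrative assumptions, and obtaining the event stream is outside the scope of this diff.

# Illustrative sketch only -- assumes `events` is the async stream of
# agentic-system turn-response events that EventLogger.log() consumes.
import asyncio

from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_toolchain.agentic_system.event_logger import EventLogger


async def print_turn(events) -> None:
    logger = EventLogger()
    # tool_prompt_format controls how ToolUtils.encode_tool_call renders a
    # tool call when a step_complete event carries model tool_calls.
    async for _event, log_event in logger.log(
        events, tool_prompt_format=ToolPromptFormat.json
    ):
        if log_event is not None:
            print(log_event)


# asyncio.run(print_turn(events))

Passing ToolPromptFormat.json matches the new default on every changed signature, so existing callers keep their current behavior; other ToolPromptFormat values defined by llama_models would only change how tool calls are encoded and printed.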