From 35aac869978e0e24baf1c2e535cd9d6b60e4e559 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 7 Apr 2025 14:54:43 -0700 Subject: [PATCH] update generate prompt format --- llama_stack/models/llama/prompt_format.py | 56 +++++++++++------------ scripts/generate_prompt_format.py | 42 ++++------------- 2 files changed, 36 insertions(+), 62 deletions(-) diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py index 6756aebfe..edb34620c 100644 --- a/llama_stack/models/llama/prompt_format.py +++ b/llama_stack/models/llama/prompt_format.py @@ -35,7 +35,6 @@ from .llama3.template_data import ( system_message_builtin_tools_only, system_message_custom_tools_only, ) -from .llama4.datatypes import LLMInput class TextCompletionContent(BaseModel): @@ -74,21 +73,22 @@ class UseCase(BaseModel): text += dialog text += "\n\n" continue - - elif isinstance(dialog, TextCompletionContent): - input_tokens, output_tokens = generator.text_completion_raw( - dialog.content, - temperature=0.1, - top_p=0.95, - max_gen_len=64, - ) else: - input_tokens, output_tokens = generator.chat_completion_raw( - dialog, - temperature=0.0, - top_p=0.95, - max_gen_len=self.max_gen_len, + batch = [dialog] + method = ( + generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion ) + input_tokens = [] + output_tokens = [] + for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95): + result = token_results[0] + if result.source == "input": + input_tokens.append(result.token) + else: + output_tokens.append(result.token) + + if result.finished: + break text += "##### Input Prompt Format\n" # FIXME: This is added to undo the hack in chat_formatter where @@ -124,27 +124,27 @@ class Llama4UseCase(UseCase): text = "" tokenizer = Tokenizer.get_instance() - temperature = 0.0 for dialog in self.dialogs: if isinstance(dialog, str): text += dialog text += "\n\n" continue - - elif isinstance(dialog, TextCompletionContent): - # TODO pass the raw input and do the encoding in the text completion function - input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False) - llm_input = LLMInput(tokens=input_tokens) - output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw( - llm_input, temperature=temperature, max_gen_len=self.max_gen_len - ) - else: - input_tokens, output_tokens = generator.chat_completion_raw( - dialog, - temperature=temperature, - max_gen_len=self.max_gen_len, + batch = [dialog] + method = ( + generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion ) + input_tokens = [] + output_tokens = [] + for token_results in method(batch, echo=True, temperature=0.0): + result = token_results[0] + if result.source == "input": + input_tokens.append(result.token) + else: + output_tokens.append(result.token) + + if result.finished: + break text += "##### Input Prompt Format\n" text += _code_block(tokenizer.decode(input_tokens)) diff --git a/scripts/generate_prompt_format.py b/scripts/generate_prompt_format.py index 08c5bea22..5598e35f6 100755 --- a/scripts/generate_prompt_format.py +++ b/scripts/generate_prompt_format.py @@ -5,13 +5,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - # Run this script: # torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E ~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/ llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md @@ -22,16 +15,9 @@ from pathlib import Path import fire +from llama_stack.models.llama.llama3.generation import Llama3 +from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.providers.inline.inference.meta_reference.config import ( - MetaReferenceInferenceConfig, -) -from llama_stack.providers.inline.inference.meta_reference.llama3.generation import ( - Llama3, -) -from llama_stack.providers.inline.inference.meta_reference.llama4.generation import ( - Llama4, -) THIS_DIR = Path(__file__).parent.resolve() @@ -50,24 +36,12 @@ def run_main( if not llama_model: raise ValueError(f"Model {model_id} not found") - if not llama4: - config = MetaReferenceInferenceConfig( - model=model_id, - max_seq_len=4096, - max_batch_size=1, - checkpoint_dir=checkpoint_dir, - ) - generator = Llama3.build( - config=config, - model_id=model_id, - llama_model=llama_model, - ) - else: - generator = Llama4.build( - ckpt_dir=checkpoint_dir, - max_seq_len=4096, - max_batch_size=1, - ) + cls = Llama4 if llama4 else Llama3 + generator = cls.build( + ckpt_dir=checkpoint_dir, + max_seq_len=4096, + max_batch_size=1, + ) use_cases = module.usecases() text = ""