update generate prompt format

commit 35aac86997 (parent 76004eacb4)
2 changed files with 36 additions and 62 deletions
```diff
@@ -35,7 +35,6 @@ from .llama3.template_data import (
     system_message_builtin_tools_only,
     system_message_custom_tools_only,
 )
-from .llama4.datatypes import LLMInput


 class TextCompletionContent(BaseModel):
```
```diff
@@ -74,21 +73,22 @@ class UseCase(BaseModel):
                 text += dialog
                 text += "\n\n"
                 continue

-            elif isinstance(dialog, TextCompletionContent):
-                input_tokens, output_tokens = generator.text_completion_raw(
-                    dialog.content,
-                    temperature=0.1,
-                    top_p=0.95,
-                    max_gen_len=64,
-                )
             else:
-                input_tokens, output_tokens = generator.chat_completion_raw(
-                    dialog,
-                    temperature=0.0,
-                    top_p=0.95,
-                    max_gen_len=self.max_gen_len,
-                )
+                batch = [dialog]
+                method = (
+                    generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
+                )
+                input_tokens = []
+                output_tokens = []
+                for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
+                    result = token_results[0]
+                    if result.source == "input":
+                        input_tokens.append(result.token)
+                    else:
+                        output_tokens.append(result.token)
+
+                    if result.finished:
+                        break
             text += "##### Input Prompt Format\n"

             # FIXME: This is added to undo the hack in chat_formatter where
```
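Instead of two raw entry points, the rewritten branch streams one batched call with `echo=True` and separates prompt tokens from generated tokens by `result.source`. Below is a runnable, self-contained sketch of that partitioning pattern; `TokenResult` and `fake_stream` are hypothetical stand-ins for the generator's streaming results, not types from the repository:

```python
from dataclasses import dataclass


@dataclass
class TokenResult:
    token: int
    source: str        # "input" for echoed prompt tokens, "output" for generated ones
    finished: bool = False


def fake_stream(batch, echo=True):
    # Stand-in for generator.completion / generator.chat_completion:
    # with echo=True, prompt tokens are yielded before generated tokens.
    prompt, generated = [1, 2, 3], [4, 5]
    for t in prompt:
        yield [TokenResult(t, "input")]
    for i, t in enumerate(generated):
        yield [TokenResult(t, "output", finished=(i == len(generated) - 1))]


input_tokens, output_tokens = [], []
for token_results in fake_stream(["dialog"], echo=True):
    result = token_results[0]  # batch size is 1 here
    (input_tokens if result.source == "input" else output_tokens).append(result.token)
    if result.finished:
        break

assert input_tokens == [1, 2, 3]
assert output_tokens == [4, 5]
```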
```diff
@@ -124,27 +124,27 @@ class Llama4UseCase(UseCase):

         text = ""
         tokenizer = Tokenizer.get_instance()
-        temperature = 0.0
         for dialog in self.dialogs:
             if isinstance(dialog, str):
                 text += dialog
                 text += "\n\n"
                 continue

-            elif isinstance(dialog, TextCompletionContent):
-                # TODO pass the raw input and do the encoding in the text completion function
-                input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
-                llm_input = LLMInput(tokens=input_tokens)
-                output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
-                    llm_input, temperature=temperature, max_gen_len=self.max_gen_len
-                )
-
             else:
-                input_tokens, output_tokens = generator.chat_completion_raw(
-                    dialog,
-                    temperature=temperature,
-                    max_gen_len=self.max_gen_len,
-                )
+                batch = [dialog]
+                method = (
+                    generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
+                )
+                input_tokens = []
+                output_tokens = []
+                for token_results in method(batch, echo=True, temperature=0.0):
+                    result = token_results[0]
+                    if result.source == "input":
+                        input_tokens.append(result.token)
+                    else:
+                        output_tokens.append(result.token)
+
+                    if result.finished:
+                        break

             text += "##### Input Prompt Format\n"
             text += _code_block(tokenizer.decode(input_tokens))
```
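Both use cases finish by decoding the collected input tokens and wrapping them in a fenced block for prompt_format.md. The `_code_block` helper is referenced but not shown in this diff; a plausible minimal definition, offered only as an assumption:

```python
def _code_block(text: str) -> str:
    # Assumed helper: wrap decoded token text in a fenced code block
    # for the generated prompt_format.md document.
    fence = "`" * 3
    return f"{fence}\n{text}\n{fence}\n"
```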
```diff
@@ -5,13 +5,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 # Run this script:
 # torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E ~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/ llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md
```
```diff
@@ -22,16 +15,9 @@ from pathlib import Path

 import fire

+from llama_stack.models.llama.llama3.generation import Llama3
+from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.inline.inference.meta_reference.config import (
-    MetaReferenceInferenceConfig,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama3.generation import (
-    Llama3,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama4.generation import (
-    Llama4,
-)

 THIS_DIR = Path(__file__).parent.resolve()
```
```diff
@@ -50,24 +36,12 @@ def run_main(
     if not llama_model:
         raise ValueError(f"Model {model_id} not found")

-    if not llama4:
-        config = MetaReferenceInferenceConfig(
-            model=model_id,
-            max_seq_len=4096,
-            max_batch_size=1,
-            checkpoint_dir=checkpoint_dir,
-        )
-        generator = Llama3.build(
-            config=config,
-            model_id=model_id,
-            llama_model=llama_model,
-        )
-    else:
-        generator = Llama4.build(
-            ckpt_dir=checkpoint_dir,
-            max_seq_len=4096,
-            max_batch_size=1,
-        )
+    cls = Llama4 if llama4 else Llama3
+    generator = cls.build(
+        ckpt_dir=checkpoint_dir,
+        max_seq_len=4096,
+        max_batch_size=1,
+    )

     use_cases = module.usecases()
     text = ""
```
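The script-side refactor works because `Llama3.build` and `Llama4.build` now accept the same arguments, so the class can be chosen first and the call site stays uniform. A small sketch with hypothetical stub classes illustrating the duck-typed selection:

```python
# Hypothetical stubs standing in for Llama3/Llama4; only the shared
# classmethod signature matters for the `cls = A if cond else B` pattern.
class Llama3Stub:
    @classmethod
    def build(cls, ckpt_dir: str, max_seq_len: int, max_batch_size: int):
        return f"Llama3(ckpt_dir={ckpt_dir}, seq={max_seq_len}, batch={max_batch_size})"


class Llama4Stub:
    @classmethod
    def build(cls, ckpt_dir: str, max_seq_len: int, max_batch_size: int):
        return f"Llama4(ckpt_dir={ckpt_dir}, seq={max_seq_len}, batch={max_batch_size})"


def build_generator(llama4: bool, checkpoint_dir: str):
    cls = Llama4Stub if llama4 else Llama3Stub
    return cls.build(ckpt_dir=checkpoint_dir, max_seq_len=4096, max_batch_size=1)


print(build_generator(llama4=False, checkpoint_dir="~/.llama/checkpoints/demo"))
```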