update generate prompt format

Mirror of https://github.com/meta-llama/llama-stack.git
Commit 35aac86997, parent 76004eacb4
2 changed files with 36 additions and 62 deletions
@@ -35,7 +35,6 @@ from .llama3.template_data import (
     system_message_builtin_tools_only,
     system_message_custom_tools_only,
 )
-from .llama4.datatypes import LLMInput
 
 
 class TextCompletionContent(BaseModel):
@@ -74,21 +73,22 @@ class UseCase(BaseModel):
                 text += dialog
                 text += "\n\n"
                 continue
 
-            elif isinstance(dialog, TextCompletionContent):
-                input_tokens, output_tokens = generator.text_completion_raw(
-                    dialog.content,
-                    temperature=0.1,
-                    top_p=0.95,
-                    max_gen_len=64,
-                )
-            else:
-                input_tokens, output_tokens = generator.chat_completion_raw(
-                    dialog,
-                    temperature=0.0,
-                    top_p=0.95,
-                    max_gen_len=self.max_gen_len,
-                )
+            batch = [dialog]
+            method = (
+                generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
+            )
+            input_tokens = []
+            output_tokens = []
+            for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
+                result = token_results[0]
+                if result.source == "input":
+                    input_tokens.append(result.token)
+                else:
+                    output_tokens.append(result.token)
+
+                if result.finished:
+                    break
 
             text += "##### Input Prompt Format\n"
 
             # FIXME: This is added to undo the hack in chat_formatter where
@@ -124,27 +124,27 @@ class Llama4UseCase(UseCase):
 
         text = ""
         tokenizer = Tokenizer.get_instance()
         temperature = 0.0
         for dialog in self.dialogs:
             if isinstance(dialog, str):
                 text += dialog
                 text += "\n\n"
                 continue
 
-            elif isinstance(dialog, TextCompletionContent):
-                # TODO pass the raw input and do the encoding in the text completion function
-                input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
-                llm_input = LLMInput(tokens=input_tokens)
-                output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
-                    llm_input, temperature=temperature, max_gen_len=self.max_gen_len
-                )
-
-            else:
-                input_tokens, output_tokens = generator.chat_completion_raw(
-                    dialog,
-                    temperature=temperature,
-                    max_gen_len=self.max_gen_len,
-                )
+            batch = [dialog]
+            method = (
+                generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
+            )
+            input_tokens = []
+            output_tokens = []
+            for token_results in method(batch, echo=True, temperature=0.0):
+                result = token_results[0]
+                if result.source == "input":
+                    input_tokens.append(result.token)
+                else:
+                    output_tokens.append(result.token)
+
+                if result.finished:
+                    break
 
             text += "##### Input Prompt Format\n"
             text += _code_block(tokenizer.decode(input_tokens))
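Note that the Llama4 path streams with `temperature=0.0` and no `top_p`, while the base `UseCase` hunk above uses `temperature=0.1, top_p=0.95`. The decoded prompt is then rendered through `_code_block`, whose definition lies outside this diff; a hypothetical sketch of what such a helper might look like:

```python
def _code_block(text: str) -> str:
    # Hypothetical sketch: the real _code_block is not shown in this diff.
    # It presumably wraps the decoded token string in a fenced block for
    # the generated markdown document.
    return "```\n" + text + "\n```\n"
```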