feat: introduce llama4 support (#1877)
As the title says. Details are in the README and elsewhere.
parent 23a99a4b22
commit b8f1561956

61 changed files with 205222 additions and 6439 deletions
@@ -27,6 +27,10 @@ from llama_stack.models.llama.datatypes import (
     ToolCall,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.llama4.tokenizer import Tokenizer
+from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
+    LLMInput,
+)

 from .llama3.interface import LLama31Interface
 from .llama3.template_data import (
@@ -46,6 +50,7 @@ class UseCase(BaseModel):
     dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
     notes: str = ""
     tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
+    max_gen_len: int = 512

     def md_format(self):
         section = textwrap.dedent(
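The new field gives every use case an explicit generation budget instead of the 512-token value previously hard-coded at the call sites in the next hunk. A minimal, self-contained sketch of the idea; `UseCaseSketch` is a hypothetical stand-in for the real `UseCase` class, with only a few of its fields and `dialogs` simplified to plain strings:

```python
# Hypothetical stand-in for UseCase; not the class from this diff.
from typing import List

from pydantic import BaseModel, Field


class UseCaseSketch(BaseModel):
    dialogs: List[str] = Field(default_factory=list)
    notes: str = ""
    max_gen_len: int = 512  # new: per-use-case generation budget


short_case = UseCaseSketch(dialogs=["hello"], max_gen_len=64)
default_case = UseCaseSketch(dialogs=["hello"])
print(short_case.max_gen_len, default_case.max_gen_len)  # 64 512
```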
@@ -75,17 +80,16 @@ class UseCase(BaseModel):
             elif isinstance(dialog, TextCompletionContent):
                 input_tokens, output_tokens = generator.text_completion_raw(
                     dialog.content,
-                    max_gen_len=64,
                     temperature=0.1,
                     top_p=0.95,
+                    max_gen_len=64,
                 )
             else:
                 input_tokens, output_tokens = generator.chat_completion_raw(
                     dialog,
-                    max_gen_len=512,
                     temperature=0.0,
                     top_p=0.95,
                     tool_prompt_format=self.tool_prompt_format,
+                    max_gen_len=self.max_gen_len,
                 )
             text += "##### Input Prompt Format\n"
@@ -115,6 +119,45 @@ class UseCase(BaseModel):
         return section


+class Llama4UseCase(UseCase):
+    def dialogs_to_text(self, generator) -> str:
+        def _code_block(text):
+            return f"```\n{text}\n```"
+
+        text = ""
+        tokenizer = Tokenizer.get_instance()
+        temperature = 0.0
+        for dialog in self.dialogs:
+            if isinstance(dialog, str):
+                text += dialog
+                text += "\n\n"
+                continue
+
+            elif isinstance(dialog, TextCompletionContent):
+                # TODO pass the raw input and do the encoding in the text completion function
+                input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
+                llm_input = LLMInput(tokens=input_tokens)
+                output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
+                    llm_input, temperature=temperature, max_gen_len=self.max_gen_len
+                )
+
+            else:
+                input_tokens, output_tokens = generator.chat_completion_raw(
+                    dialog,
+                    temperature=temperature,
+                    max_gen_len=self.max_gen_len,
+                )
+
+            text += "##### Input Prompt Format\n"
+            text += _code_block(tokenizer.decode(input_tokens))
+            text += "\n\n"
+            text += "##### Model Response Format\n"
+            text += _code_block(tokenizer.decode(output_tokens))
+            text += "\n\n"
+
+        return text
+
+
 def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
     interface = LLama31Interface(tool_prompt_format)

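Taken together, the Llama 4 path is: encode the raw prompt with the Llama 4 tokenizer, wrap the tokens in `LLMInput`, hand that to the generator's `text_completion_raw`, and decode both sides for the markdown output. A hedged sketch of that flow, using an invented `EchoGenerator` in place of the real meta-reference generator; it assumes the llama-stack package and its Llama 4 tokenizer assets are installed:

```python
# Sketch of the flow Llama4UseCase.dialogs_to_text drives. EchoGenerator is
# invented for illustration; only its text_completion_raw signature and
# 3-tuple return shape mirror what the method above expects.
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import LLMInput


class EchoGenerator:
    """Fake generator that 'completes' a prompt by echoing its tokens back."""

    def text_completion_raw(self, llm_input: LLMInput, temperature: float, max_gen_len: int):
        tokenizer = Tokenizer.get_instance()
        output_tokens = list(llm_input.tokens)[:max_gen_len]  # assumes LLMInput exposes .tokens
        decoded_tokens = [tokenizer.decode([t]) for t in output_tokens]
        token_logprobs = [0.0] * len(output_tokens)
        return output_tokens, decoded_tokens, token_logprobs


tokenizer = Tokenizer.get_instance()
input_tokens = tokenizer.encode("The color of the sky is", bos=True, eos=False)
llm_input = LLMInput(tokens=input_tokens)

generator = EchoGenerator()
output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
    llm_input, temperature=0.0, max_gen_len=64
)

print("##### Input Prompt Format")
print(tokenizer.decode(input_tokens))
print("##### Model Response Format")
print(tokenizer.decode(output_tokens))
```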