Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)

feat: introduce llama4 support (#1877)

As the title says. Details are in the README and elsewhere.

Parent commit: 23a99a4b22
This commit:   b8f1561956

61 changed files with 205,222 additions and 6,439 deletions
llama_stack/models/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

llama_stack/models/llama/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -231,6 +231,7 @@ class ModelFamily(Enum):
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"
    llama4 = "llama4"
    safety = "safety"

@@ -272,6 +273,12 @@ class CoreModelId(Enum):
    # Llama 3.3 family
    llama3_3_70b_instruct = "Llama3.3-70B-Instruct"

    # Llama 4 family
    llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
    llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
    llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
    llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"

    # Safety models
    llama_guard_3_8b = "Llama-Guard-3-8B"
    llama_guard_2_8b = "Llama-Guard-2-8B"

@@ -332,6 +339,13 @@ def model_family(model_id) -> ModelFamily:
        CoreModelId.llama3_3_70b_instruct,
    ]:
        return ModelFamily.llama3_3
    elif model_id in [
        CoreModelId.llama4_scout_17b_16e,
        CoreModelId.llama4_scout_17b_16e_instruct,
        CoreModelId.llama4_maverick_17b_128e,
        CoreModelId.llama4_maverick_17b_128e_instruct,
    ]:
        return ModelFamily.llama4
    elif model_id in [
        CoreModelId.llama_guard_3_8b,
        CoreModelId.llama_guard_2_8b,

@@ -379,6 +393,7 @@ class Model(BaseModel):
            ModelFamily.llama3_1,
            ModelFamily.llama3_2,
            ModelFamily.llama3_3,
            ModelFamily.llama4,
            ModelFamily.safety,
        ]

@@ -396,6 +411,16 @@ class Model(BaseModel):
            if self.quantization_format == CheckpointQuantizationFormat.int4:
                return 8192
            return 131072
        elif self.model_family == ModelFamily.llama4:
            if self.core_model_id in {
                CoreModelId.llama4_scout_17b_16e,
                CoreModelId.llama4_maverick_17b_128e,
            }:
                return 262144
            if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
                return 10485760
            if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
                return 1048576
        elif self.core_model_id in [
            CoreModelId.llama_guard_3_8b,
            CoreModelId.llama_guard_3_11b_vision,
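For orientation (not part of the commit), here is a minimal sketch of how the new enum values and the `model_family()` helper above could be exercised. It assumes these names are importable from `llama_stack.models.llama.datatypes`, as the import statements elsewhere in this diff suggest.

```python
# Sketch only: exercise the Llama 4 additions to the datatypes module.
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, model_family

model_id = CoreModelId.llama4_scout_17b_16e_instruct

# The new branch above maps all four Llama 4 SKUs to ModelFamily.llama4.
assert model_family(model_id) == ModelFamily.llama4

# Context-window sizes wired up above (in tokens):
#   base Scout / Maverick                -> 262,144
#   Llama-4-Scout-17B-16E-Instruct       -> 10,485,760 (10M)
#   Llama-4-Maverick-17B-128E-Instruct   -> 1,048,576  (1M)
```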
@@ -21,8 +21,7 @@ from llama_stack.models.llama.datatypes import (
    ToolCall,
    ToolPromptFormat,
)

from ..prompt_format import (
from llama_stack.models.llama.prompt_format import (
    # llama3_1_e2e_tool_call_dialog,
    TextCompletionContent,
    UseCase,
llama_stack/models/llama/llama4/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/models/llama/llama4/chat_format.py (new file, 326 lines)

@@ -0,0 +1,326 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import io
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image as PIL_Image

# TODO: either fork these or move them to the common package
from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    RawContent,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    Role,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
    LLMInput,
)
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
    ResizeNormalizeImageTransform,
    VariableSizeImageTransform,
)

from .tokenizer import Tokenizer


def role_str(role: Role) -> str:
    role_strs = {
        Role.user: "user",
        Role.system: "system",
        Role.tool: "ipython",  # special
        Role.assistant: "assistant",
    }
    return role_strs[role]


@dataclass
class TransformedImage:
    image_tiles: torch.Tensor
    # is the aspect ratio needed anywhere?
    aspect_ratio: Tuple[int, int]


def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
    if image.mode == "RGBA":
        image.load()  # for png.split()
        new_img = PIL_Image.new("RGB", image.size, bg)
        new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
        return new_img
    return image.convert("RGB")


class ChatFormat:
    possible_headers: Dict[Role, str]

    def __init__(
        self,
        tokenizer: Tokenizer,
        vision_args: Optional[VisionArgs] = None,
        max_num_chunks: int = 16,
    ):
        self.tokenizer = tokenizer
        self.vision_args = vision_args
        self.max_num_chunks = max_num_chunks

        self.possible_headers = {role: f"<|header_start|>{role_str(role)}<|header_end|>\n\n" for role in Role}

        self.image_transform = None
        self.dynamic_image_transform = None
        if vision_args:
            self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
            self.image_transform = ResizeNormalizeImageTransform(
                vision_args.image_size.width, vision_args.image_size.height
            )

    def _encode_header(self, role: str) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|header_start|>"])

        # TODO: need to check if this is correct
        tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|header_end|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_content(self, content: RawContent) -> LLMInput:
        tokens, images = self._encode_content(content, bos=True)
        return self._model_input_from_tokens_images(tokens, images)

    def _encode_image(
        self,
        transformed_image: TransformedImage,
    ) -> List[int]:
        assert self.vision_args is not None, "The model is not vision-enabled"

        image_tensor = transformed_image.image_tiles
        image_channels = image_tensor.shape[-3]
        image_height = image_tensor.shape[-2]
        image_width = image_tensor.shape[-1]
        image_chunks = image_tensor.view(-1, image_channels, image_height, image_width).shape[0]

        patch_height = self.vision_args.patch_size.height
        patch_width = self.vision_args.patch_size.width

        if image_height % patch_height != 0:
            raise ValueError(f"{image_height=} not divisible by {patch_height=}")
        if image_width % patch_width != 0:
            raise ValueError(f"{image_width=} not divisible by {patch_width=}")

        ds_ratio = int(round(1.0 / (self.vision_args.pixel_shuffle_ratio**2)))
        n_patches_per_chunk = int((image_height // patch_height) * (image_width // patch_width) // ds_ratio)

        image_ar = transformed_image.aspect_ratio
        tokens = [self.tokenizer.special_tokens["<|image_start|>"]]
        if image_chunks == 1:
            tokens += [self.tokenizer.special_tokens["<|image|>"]]
            tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
            tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
        else:
            ratio_h, ratio_w = image_ar
            for _ in range(ratio_h):
                for xx in range(ratio_w):
                    tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
                    if xx < ratio_w - 1:
                        tokens.append(self.tokenizer.special_tokens["<|tile_x_separator|>"])

                tokens.append(self.tokenizer.special_tokens["<|tile_y_separator|>"])

            tokens += [self.tokenizer.special_tokens["<|image|>"]]
            tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
            tokens += [self.tokenizer.special_tokens["<|image_end|>"]]

        return tokens

    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
        tokens = []
        tranformed_images = []

        added_bos = False

        def _process(c):
            nonlocal added_bos, bos

            if isinstance(c, str) or isinstance(c, RawTextItem):
                if isinstance(c, RawTextItem):
                    c = c.text
                tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
                added_bos = True

            elif isinstance(c, RawMediaItem):
                if not self.vision_args:
                    raise ValueError("The model is not vision-enabled, but a media item was found")

                bos = False if added_bos else bos
                if bos:
                    tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
                    added_bos = True

                bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
                image = PIL_Image.open(bytes_io)
                image = convert_rgba_to_rgb(image)
                image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)

                if image_tiles.shape[0] > 1:
                    image_global = self.image_transform(image)
                    image_global = image_global.unsqueeze(0)
                    image_combine = torch.cat((image_tiles, image_global), dim=0)
                    image_tiles = image_combine

                transformed_image = TransformedImage(image_tiles=image_tiles, aspect_ratio=ar)
                tokens.extend(self._encode_image(transformed_image))
                tranformed_images.append(transformed_image)

        if isinstance(content, list):
            for c in content:
                _process(c)
        else:
            _process(content)

        return tokens, tranformed_images

    def encode_message(
        self, message: RawMessage, tool_prompt_format: ToolPromptFormat
    ) -> Tuple[List[int], List[TransformedImage]]:
        tokens = self._encode_header(message.role)
        images = []

        def _process_content(c):
            toks, imgs = self._encode_content(c)
            tokens.extend(toks)
            images.extend(imgs)

        if message.role == "assistant" and len(message.tool_calls) > 0:
            tokens.append(self.tokenizer.special_tokens["<|python_start|>"])

        _process_content(message.content)

        if message.role == "assistant" and len(message.tool_calls) > 0:
            tokens.append(self.tokenizer.special_tokens["<|python_end|>"])

        if message.role == "user" and message.context is not None:
            # This is RAG context; why is it here in the chat format? I don't think
            # this is needed and can be moved upwards
            _process_content("\n\n")
            _process_content(message.context)

        if message.role == "assistant":
            for t in message.tool_calls:
                content = ToolUtils.encode_tool_call(t, tool_prompt_format)
                _process_content(content)

        eom = False
        if message.role == "assistant":
            eom = message.stop_reason == StopReason.end_of_message

        tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
        return tokens, images

    def encode_dialog_prompt(
        self,
        messages: List[RawMessage],
        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
    ) -> LLMInput:
        tokens = []
        images = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in messages:
            toks, imgs = self.encode_message(message, tool_prompt_format)
            tokens.extend(toks)
            images.extend(imgs)

        # Add the start of an assistant message for the model to complete.
        tokens.extend(self._encode_header("assistant"))

        return self._model_input_from_tokens_images(tokens, images)

    # TODO(this should be generic, not only for assistant messages)
    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
        content = self.tokenizer.decode(tokens)

        return self.decode_assistant_message_from_content(content, stop_reason)

    def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
        content = content.strip(" ")
        header_str = self.possible_headers[Role.assistant]
        if content.startswith(header_str):
            content = content[len(header_str) :]

        ipython = content.startswith("<|python_start|>")
        if ipython:
            content = content[len("<|python_start|>") :]
            content = content.replace("<|python_end|>", "")

        if content.endswith("<|eot|>"):
            content = content[: -len("<|eot|>")]
            stop_reason = StopReason.end_of_turn
        elif content.endswith("<|eom|>"):
            content = content[: -len("<|eom|>")]
            stop_reason = StopReason.end_of_message

        tool_name = None
        tool_arguments = {}

        custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
        if custom_tool_info is not None:
            tool_name, tool_arguments = custom_tool_info
            # Sometimes, when the agent has custom tools alongside builtin tools,
            # the agent responds to builtin tool calls in the format of the custom tools.
            # This code tries to handle that case.
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
                tool_arguments = {
                    "query": list(tool_arguments.values())[0],
                }
        else:
            builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
            if builtin_tool_info is not None:
                tool_name, query = builtin_tool_info
                tool_arguments = {
                    "query": query,
                }
                if tool_name in BuiltinTool.__members__:
                    tool_name = BuiltinTool[tool_name]
            elif ipython:
                tool_name = BuiltinTool.code_interpreter
                tool_arguments = {
                    "code": content,
                }

        tool_calls = []
        if tool_name is not None and tool_arguments is not None:
            call_id = str(uuid.uuid4())
            tool_calls.append(
                ToolCall(
                    call_id=call_id,
                    tool_name=tool_name,
                    arguments=tool_arguments,
                )
            )
            content = ""

        return RawMessage(
            role="assistant",
            content=content,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
        )

    def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
        return LLMInput(
            tokens=tokens,
            images=[x.image_tiles for x in images] if len(images) > 0 else None,
        )
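For orientation (not part of the commit), a minimal sketch of how the new `ChatFormat` and `Tokenizer` could be wired together for a text-only dialog. It assumes the bundled `tokenizer.model` file is present next to `tokenizer.py` and that `LLMInput.tokens` holds the encoded prompt, as the code above suggests.

```python
# Sketch only: encode a text-only dialog with the new Llama 4 chat format.
from llama_stack.models.llama.datatypes import RawMessage
from llama_stack.models.llama.llama4.chat_format import ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer

tokenizer = Tokenizer.get_instance()   # loads llama4/tokenizer.model
chat_format = ChatFormat(tokenizer)    # vision_args=None, so text-only

llm_input = chat_format.encode_dialog_prompt(
    [
        RawMessage(role="system", content="You are a helpful assistant"),
        RawMessage(role="user", content="Answer who are you in the form of jeopardy?"),
    ]
)

# The prompt starts with <|begin_of_text|> and ends with an opened assistant
# header, ready for the model to complete.
print(tokenizer.decode(llm_input.tokens))
```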
llama_stack/models/llama/llama4/prompt_format.md (new file, 277 lines)
(File diff suppressed because one or more lines are too long.)
llama_stack/models/llama/llama4/prompts.py (new file, 313 lines)

@@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import textwrap
from io import BytesIO
from pathlib import Path
from typing import List

from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
from llama_stack.models.llama.prompt_format import (
    Llama4UseCase,
    TextCompletionContent,
    UseCase,
)

THIS_DIR = Path(__file__).parent


def usecases(base_model: bool = False) -> List[UseCase | str]:
    with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
        img_small_dog = f.read()
    with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:
        img_dog = f.read()
    with open(THIS_DIR.parent / "resources/pasta.jpeg", "rb") as f:
        img_pasta = f.read()
    out = []
    out.extend(
        [
            textwrap.dedent(
                """
                # Llama 4 - Prompt Formats
                ## Tokens
                Here is a list of special tokens that are supported by Llama 4:
                - `<|begin_of_text|>`: Specifies the start of the prompt
                - `<|end_of_text|>`: The model will cease to generate more tokens. This token is generated only by the base models.
                - `<|header_start|>` and `<|header_end|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant].
                - `<|eot|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
                    - at the end of a direct interaction between the model and the user
                    - at the end of multiple interactions between the model and any available tools
                    This token signals to the executor that the model has finished generating a response.
                - `<|image_start|>` and `<|image_end|>`: These tokens enclose the image data in the prompt.
                - `<|patch|>`: This token represents a piece of the tile/image.
                - `<|tile_y_separator|>` and `<|tile_x_separator|>`: These tokens are used to separate the y and x tiles of an image.
                - `<|image|>`: In the new architecture, this token now separates the regular-sized image information from a downsized version of it that fits in a single tile. The longer side is used for calculating the scale factor, and the rest is padded to fit the tile.
                """
            ),
            textwrap.dedent(
                """
                There are 3 different roles that are supported by Llama 4:
                - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
                - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
                - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
                """
            ),
        ]
    )

    if base_model:
        out.extend(
            [
                "# Llama 4 Base Model",
                Llama4UseCase(
                    title="Text completion - Paris information",
                    description="Text completion for the Llama 4 base model uses this format.",
                    dialogs=[TextCompletionContent(content="The capital of France is Paris")],
                ),
                Llama4UseCase(
                    title="Text completion - The color of the sky",
                    description="Text completion for the Llama 4 base model uses this format.",
                    dialogs=[
                        TextCompletionContent(content="The color of the sky is blue but sometimes it can also be")
                    ],
                    notes="",
                ),
                Llama4UseCase(
                    title="Text completion - Translation example",
                    description="Text completion for the Llama 4 base model uses this format.",
                    dialogs=[
                        TextCompletionContent(
                            content="""apple is pomme,
bannana is banane,
cherry is"""
                        )
                    ],
                    notes="",
                ),
            ]
        )

    out.extend(
        [
            "# Llama 4 Instruct Model",
            Llama4UseCase(
                title="Simple User and assistant conversation",
                description="Here is a regular multi-turn user assistant conversation and how it is formatted.",
                dialogs=[
                    [
                        RawMessage(role="system", content="You are a helpful assistant"),
                        RawMessage(
                            role="user",
                            content="Answer who are you in the form of jeopardy?",
                        ),
                    ]
                ],
                notes="",
                max_gen_len=512,
            ),
            "# Image prompt format",
            Llama4UseCase(
                title="Single image prompt format - small image",
                description="This example passes an image that is smaller than the tile size, to show that the tile separator tokens are not needed",
                dialogs=[
                    [
                        RawMessage(
                            role="user",
                            content=[
                                RawMediaItem(data=BytesIO(img_small_dog)),
                                RawTextItem(text="Describe this image in two sentences"),
                            ],
                        )
                    ]
                ],
                notes="""Notice the structure of the image section:
```
<|image_start|><|image|><|patch|>...<|patch|><|image_end|>
```
This is due to the image being smaller than the tile size.
                """,
                max_gen_len=512,
            ),
            Llama4UseCase(
                title="Single image prompt format",
                description="Here is an example of how to pass an image to the model",
                dialogs=[
                    [
                        RawMessage(
                            role="user",
                            content=[
                                RawMediaItem(data=BytesIO(img_dog)),
                                RawTextItem(text="Describe this image in two sentences"),
                            ],
                        )
                    ]
                ],
                notes="""With a bigger image, the image will include the tile separator tokens. Additionally, the image tag now separates a scaled down version of the image from the regular sized image.
```
<|image_start|><|patch|>...<|patch|><|tile_x_separator|><|patch|>...<|patch|><|tile_y_separator|><|patch|>...<|patch|><|image|><|patch|>...<|patch|><|image_end|>
```
                """,
                max_gen_len=1024,
            ),
            Llama4UseCase(
                title="Multiple images prompt format",
                description="Here is an example of how to pass an image to the model",
                dialogs=[
                    [
                        RawMessage(
                            role="user",
                            content=[
                                RawMediaItem(data=BytesIO(img_dog)),
                                RawMediaItem(data=BytesIO(img_pasta)),
                                RawTextItem(text="Describe these images in two sentences"),
                            ],
                        )
                    ]
                ],
                notes="With multiple images, each one is encapsulated in its corresponding image tags.",
                max_gen_len=4096,
            ),
            "# Tool calling\nWe are continuing the format for zero shot function calling used in previous versions of Llama. All available functions can be provided either in the system message or in the user message.",
            Llama4UseCase(
                title="Zero shot function calling - system message",
                dialogs=[
                    [
                        RawMessage(
                            role="system",
                            content="""You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.

[
    {
        "name": "get_weather",
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city to get the weather for"
                },
                "metric": {
                    "type": "string",
                    "description": "The metric for weather. Options are: celsius, fahrenheit",
                    "default": "celsius"
                }
            }
        }
    }
""",
                        ),
                        RawMessage(
                            role="user",
                            content="What is the weather in SF and Seattle?",
                        ),
                    ]
                ],
                notes=textwrap.dedent(
                    """
                    - The output supports multiple, and parallel tool calls natively
                    - JSON format for defining the functions in the system prompt is similar to Llama 3.1
                    """
                ),
            ),
            Llama4UseCase(
                title="Zero shot function calling - user message",
                description=textwrap.dedent(
                    """
                    Similar to the above example, you can also provide information for all the available tools in the user message.
                    """
                ),
                dialogs=[
                    [
                        RawMessage(
                            role="user",
                            content="""Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
    {
        "name": "get_user_info",
        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
        "parameters": {
            "type": "dict",
            "required": [
                "user_id"
            ],
            "properties": {
                "user_id": {
                    "type": "integer",
                    "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
                },
                "special": {
                    "type": "string",
                    "description": "Any special information or parameters that need to be considered while fetching user details.",
                    "default": "none"
                }
            }
        }
    }
]

Should you decide to return the function call(s), put them in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]

You SHOULD NOT include any other text in the response.""",
                        ),
                    ]
                ],
                notes=textwrap.dedent(
                    """
                    - The tool call format for the model is the same whether your function calls are provided in the system or user message.
                    """
                ),
            ),
            Llama4UseCase(
                title="Tool calling with custom formats",
                description=textwrap.dedent(
                    """
                    Here is an example of how you could also write custom instructions for the model to do zero shot tool calling.
                    In this example, we define a custom tool calling format using the `<function>` tag.
                    """
                ),
                dialogs=[
                    [
                        RawMessage(
                            role="user",
                            content="""You have access to the following functions:\nUse the function 'trending_songs' to 'Returns the trending songs on a Music site':\n{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}\n\nThink very carefully before calling functions.\nIf you choose to call a function ONLY reply in the following format with no prefix or suffix:\n\n<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|>""",
                        ),
                        RawMessage(
                            role="user",
                            content="Use tools to get latest trending songs",
                        ),
                    ]
                ],
            ),
        ]
    )

    return out
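A small sketch (not part of the commit) of how the `usecases()` helper above could be inspected. It assumes `title` and `dialogs` are plain fields on `UseCase` (they are passed as constructor arguments above) and that the bundled resource images are available, since `usecases()` reads them eagerly.

```python
# Sketch only: enumerate the documented Llama 4 use cases added in this file.
from llama_stack.models.llama.llama4.prompts import usecases

for item in usecases(base_model=True):
    if isinstance(item, str):
        continue  # free-form markdown sections (token list, role list, headings)
    print(f"{item.title}: {len(item.dialogs)} dialog(s)")
```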
llama_stack/models/llama/llama4/tokenizer.model (new executable file, 200,000 lines)
(File diff suppressed because it is too large.)
llama_stack/models/llama/llama4/tokenizer.py (new file, 255 lines)

@@ -0,0 +1,255 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
    cast,
)

import tiktoken
from tiktoken.load import load_tiktoken_bpe

logger = getLogger(__name__)


# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000


_INSTANCE = None


def get_reserved_special_tokens(name, count, start_index=0):
    return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]


# 200005, ..., 200079
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
    "<|header_start|>",
    "<|header_end|>",
    "<|eom|>",
    "<|eot|>",
    "<|step|>",
    "<|text_post_train_reserved_special_token_0|>",
    "<|text_post_train_reserved_special_token_1|>",
    "<|text_post_train_reserved_special_token_2|>",
    "<|text_post_train_reserved_special_token_3|>",
    "<|text_post_train_reserved_special_token_4|>",
    "<|text_post_train_reserved_special_token_5|>",
    "<|python_start|>",
    "<|python_end|>",
    "<|finetune_right_pad|>",
] + get_reserved_special_tokens(
    "text_post_train", 61, 6
)  # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>

# 200080, ..., 201133
LLAMA4_VISION_SPECIAL_TOKENS = [
    "<|image_start|>",
    "<|image_end|>",
    "<|vision_reserved_special_token_0|>",
    "<|vision_reserved_special_token_1|>",
    "<|tile_x_separator|>",
    "<|tile_y_separator|>",
    "<|vision_reserved_special_token_2|>",
    "<|vision_reserved_special_token_3|>",
    "<|vision_reserved_special_token_4|>",
    "<|vision_reserved_special_token_5|>",
    "<|image|>",
    "<|vision_reserved_special_token_6|>",
    "<|patch|>",
] + get_reserved_special_tokens(
    "vision", 1041, 7
)  # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>


LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS

BASIC_SPECIAL_TOKENS = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|fim_prefix|>",
    "<|fim_middle|>",
    "<|fim_suffix|>",
]


class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 2048

    O200K_PATTERN = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"""  # noqa: E501

    @classmethod
    def get_instance(cls):
        global _INSTANCE

        if _INSTANCE is None:
            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
        return _INSTANCE

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)

        special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
        assert len(set(special_tokens)) == len(special_tokens)
        assert len(special_tokens) <= self.num_reserved_special_tokens

        reserved_tokens = [
            f"<|reserved_special_token_{i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
        ]
        special_tokens = special_tokens + reserved_tokens

        self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.O200K_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self.n_words: int = num_base_tokens + len(special_tokens)

        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]

        self.pad_id: int = self.special_tokens["<|finetune_right_pad|>"]
        self.eot_id: int = self.special_tokens["<|eot|>"]
        self.eom_id: int = self.special_tokens["<|eom|>"]

        self.stop_tokens = [
            self.eos_id,
            self.special_tokens["<|eom|>"],
            self.special_tokens["<|eot|>"],
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in the string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in the string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
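A small sketch (not part of the commit) of how the special-token handling described above behaves; it assumes the bundled `tokenizer.model` is available next to `tokenizer.py` and relies on standard tiktoken semantics for `allowed_special`.

```python
# Sketch only: round-trip a string and inspect one of the new special tokens.
from llama_stack.models.llama.llama4.tokenizer import Tokenizer

tok = Tokenizer.get_instance()

ids = tok.encode("Hello, Llama 4!", bos=True, eos=False)
assert ids[0] == tok.bos_id                       # <|begin_of_text|> prepended
assert tok.decode(ids[1:]) == "Hello, Llama 4!"   # BPE round-trip is lossless

# Special tokens such as <|eot|> are treated as plain text unless explicitly allowed.
as_text = tok.encode("<|eot|>", bos=False, eos=False)
as_special = tok.encode("<|eot|>", bos=False, eos=False, allowed_special="all")
assert as_special == [tok.eot_id] and as_text != as_special
```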
@@ -27,6 +27,10 @@ from llama_stack.models.llama.datatypes import (
    ToolCall,
    ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
    LLMInput,
)

from .llama3.interface import LLama31Interface
from .llama3.template_data import (

@@ -46,6 +50,7 @@ class UseCase(BaseModel):
    dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
    notes: str = ""
    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
    max_gen_len: int = 512

    def md_format(self):
        section = textwrap.dedent(

@@ -75,17 +80,16 @@ class UseCase(BaseModel):
            elif isinstance(dialog, TextCompletionContent):
                input_tokens, output_tokens = generator.text_completion_raw(
                    dialog.content,
                    max_gen_len=64,
                    temperature=0.1,
                    top_p=0.95,
                    max_gen_len=64,
                )
            else:
                input_tokens, output_tokens = generator.chat_completion_raw(
                    dialog,
                    max_gen_len=512,
                    temperature=0.0,
                    top_p=0.95,
                    tool_prompt_format=self.tool_prompt_format,
                    max_gen_len=self.max_gen_len,
                )
            text += "##### Input Prompt Format\n"

@@ -115,6 +119,45 @@ class UseCase(BaseModel):
        return section


class Llama4UseCase(UseCase):
    def dialogs_to_text(self, generator) -> str:
        def _code_block(text):
            return f"```\n{text}\n```"

        text = ""
        tokenizer = Tokenizer.get_instance()
        temperature = 0.0
        for dialog in self.dialogs:
            if isinstance(dialog, str):
                text += dialog
                text += "\n\n"
                continue

            elif isinstance(dialog, TextCompletionContent):
                # TODO pass the raw input and do the encoding in the text completion function
                input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
                llm_input = LLMInput(tokens=input_tokens)
                output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
                    llm_input, temperature=temperature, max_gen_len=self.max_gen_len
                )

            else:
                input_tokens, output_tokens = generator.chat_completion_raw(
                    dialog,
                    temperature=temperature,
                    max_gen_len=self.max_gen_len,
                )

            text += "##### Input Prompt Format\n"
            text += _code_block(tokenizer.decode(input_tokens))
            text += "\n\n"
            text += "##### Model Response Format\n"
            text += _code_block(tokenizer.decode(output_tokens))
            text += "\n\n"

        return text


def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
    interface = LLama31Interface(tool_prompt_format)
BIN llama_stack/models/llama/resources/dog.jpg (new file, binary not shown; 39 KiB)
BIN llama_stack/models/llama/resources/pasta.jpeg (new file, binary not shown; 438 KiB)
BIN llama_stack/models/llama/resources/small_dog.jpg (new file, binary not shown; 41 KiB)
@@ -19,6 +19,7 @@ from .datatypes import (
    CheckpointQuantizationFormat,
    CoreModelId,
    Model,
    ModelFamily,
    SamplingParams,
    TopPSamplingStrategy,
)

@@ -36,7 +37,13 @@ def resolve_model(descriptor: str) -> Optional[Model]:

def all_registered_models() -> List[Model]:
    return (
        llama2_family() + llama3_family() + llama3_1_family() + llama3_2_family() + llama3_3_family() + safety_models()
        llama2_family()
        + llama3_family()
        + llama3_1_family()
        + llama3_2_family()
        + llama3_3_family()
        + llama4_family()
        + safety_models()
    )

@@ -83,6 +90,60 @@ def llama3_3_family() -> List[Model]:
    ]


def llama4_family() -> List[Model]:
    return [
        *llama4_base_models(),
        *llama4_instruct_models(),
    ]


def llama4_base_models() -> List[Model]:
    return [
        Model(
            core_model_id=CoreModelId.llama4_scout_17b_16e,
            description="Llama 4 Scout (17b 16 experts model)",
            huggingface_repo="meta-llama/Llama-4-Scout-17B-16E",
            pth_file_count=8,
            arch_args={},
        ),
        Model(
            core_model_id=CoreModelId.llama4_maverick_17b_128e,
            description="Llama 4 Maverick (17b 128 experts model)",
            huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E",
            pth_file_count=8,
            arch_args={},
        ),
    ]


def llama4_instruct_models() -> List[Model]:
    return [
        Model(
            core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,
            description="Llama 4 Scout (17b 16 experts instruct model)",
            huggingface_repo="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            pth_file_count=8,
            arch_args={},
        ),
        Model(
            core_model_id=CoreModelId.llama4_maverick_17b_128e_instruct,
            description="Llama 4 Maverick (17b 128 experts instruct model)",
            huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
            pth_file_count=8,
            arch_args={},
        ),
        Model(
            core_model_id=CoreModelId.llama4_maverick_17b_128e_instruct,
            description="Llama 4 Maverick (FP8 quantized)",
            huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            quantization_format=CheckpointQuantizationFormat.fp8_mixed,
            pth_file_count=8,
            variant="fp8",
            arch_args={},
        ),
    ]


def llama2_base_models() -> List[Model]:
    return [
        Model(

@@ -989,12 +1050,24 @@ def llama_meta_pth_size(model: Model) -> int:
    if model.core_model_id not in (
        CoreModelId.llama3_1_405b,
        CoreModelId.llama3_1_405b_instruct,
        CoreModelId.llama4_maverick_17b_128e,
        CoreModelId.llama4_maverick_17b_128e_instruct,
    ):
        return 0

    if model.pth_file_count == 16:
        return 51268302389
    elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
        return 60903742309
    else:
        return 101470976045
    if model.model_family == ModelFamily.llama3_1:
        if model.pth_file_count == 16:
            return 51268302389
        elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
            return 60903742309
        else:
            return 101470976045

    if model.model_family == ModelFamily.llama4:
        if model.core_model_id == CoreModelId.llama4_maverick_17b_128e:
            return 100458118386
        elif model.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
            if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
                return 54121549657
            else:
                return 100426653046
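Finally, a sketch (not part of the commit) of how the newly registered SKUs could be listed. The `sku_list` module path is an assumption based on the function names in this hunk; `model_family()` and the `Model` fields come from the datatypes changes earlier in the diff.

```python
# Sketch only: list the Llama 4 models registered above and their HF checkpoint repos.
from llama_stack.models.llama.datatypes import ModelFamily, model_family
from llama_stack.models.llama.sku_list import all_registered_models  # assumed module path

for m in all_registered_models():
    if model_family(m.core_model_id) == ModelFamily.llama4:
        print(f"{m.core_model_id.value:40s} -> {m.huggingface_repo}")
```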