feat: introduce llama4 support (#1877)

As title says. Details in README, elsewhere.
This commit is contained in:
Ashwin Bharambe 2025-04-05 11:53:35 -07:00 committed by GitHub
parent 23a99a4b22
commit b8f1561956
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
61 changed files with 205222 additions and 6439 deletions

View file

@ -162,6 +162,10 @@ class ParallelDownloader:
raise last_exception
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
if task.total_size > 0:
self.progress.update(task.task_id, total=task.total_size)
return
async def _get_info():
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
response.raise_for_status()
@ -282,7 +286,7 @@ class ParallelDownloader:
if not tasks:
raise ValueError("No download tasks provided")
if not self.has_disk_space(tasks):
if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
raise DownloadError("Insufficient disk space for downloads")
failed_tasks = []

View file

@ -231,6 +231,7 @@ class ModelFamily(Enum):
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
@ -272,6 +273,12 @@ class CoreModelId(Enum):
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
@ -332,6 +339,13 @@ def model_family(model_id) -> ModelFamily:
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
@ -379,6 +393,7 @@ class Model(BaseModel):
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@ -396,6 +411,16 @@ class Model(BaseModel):
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,

View file

@ -21,8 +21,7 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from ..prompt_format import (
from llama_stack.models.llama.prompt_format import (
# llama3_1_e2e_tool_call_dialog,
TextCompletionContent,
UseCase,

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,326 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image as PIL_Image
# TODO: either fork these or move them to the common package
from llama_stack.models.llama.datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
ResizeNormalizeImageTransform,
VariableSizeImageTransform,
)
from .tokenizer import Tokenizer
def role_str(role: Role) -> str:
role_strs = {
Role.user: "user",
Role.system: "system",
Role.tool: "ipython", # special
Role.assistant: "assistant",
}
return role_strs[role]
@dataclass
class TransformedImage:
image_tiles: torch.Tensor
# is the aspect ratio needed anywhere?
aspect_ratio: Tuple[int, int]
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
new_img.paste(image, mask=image.split()[3]) # 3 is the alpha channel
return new_img
return image.convert("RGB")
class ChatFormat:
possible_headers: Dict[Role, str]
def __init__(
self,
tokenizer: Tokenizer,
vision_args: Optional[VisionArgs] = None,
max_num_chunks: int = 16,
):
self.tokenizer = tokenizer
self.vision_args = vision_args
self.max_num_chunks = max_num_chunks
self.possible_headers = {role: f"<|header_start|>{role_str(role)}<|header_end|>\n\n" for role in Role}
self.image_transform = None
self.dynamic_image_transform = None
if vision_args:
self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
self.image_transform = ResizeNormalizeImageTransform(
vision_args.image_size.width, vision_args.image_size.height
)
def _encode_header(self, role: str) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
# TODO: need to check if this is correct
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|header_end|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode_content(self, content: RawContent) -> LLMInput:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)
def _encode_image(
self,
transformed_image: TransformedImage,
) -> List[int]:
assert self.vision_args is not None, "The model is not vision-enabled"
image_tensor = transformed_image.image_tiles
image_channels = image_tensor.shape[-3]
image_height = image_tensor.shape[-2]
image_width = image_tensor.shape[-1]
image_chunks = image_tensor.view(-1, image_channels, image_height, image_width).shape[0]
patch_height = self.vision_args.patch_size.height
patch_width = self.vision_args.patch_size.width
if image_height % patch_height != 0:
raise ValueError(f"{image_height=} not divisible by {patch_height=}")
if image_width % patch_width != 0:
raise ValueError(f"{image_width=} not divisible by {patch_width=}")
ds_ratio = int(round(1.0 / (self.vision_args.pixel_shuffle_ratio**2)))
n_patches_per_chunk = int((image_height // patch_height) * (image_width // patch_width) // ds_ratio)
image_ar = transformed_image.aspect_ratio
tokens = [self.tokenizer.special_tokens["<|image_start|>"]]
if image_chunks == 1:
tokens += [self.tokenizer.special_tokens["<|image|>"]]
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
else:
ratio_h, ratio_w = image_ar
for _ in range(ratio_h):
for xx in range(ratio_w):
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
if xx < ratio_w - 1:
tokens.append(self.tokenizer.special_tokens["<|tile_x_separator|>"])
tokens.append(self.tokenizer.special_tokens["<|tile_y_separator|>"])
tokens += [self.tokenizer.special_tokens["<|image|>"]]
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
return tokens
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
tokens = []
tranformed_images = []
added_bos = False
def _process(c):
nonlocal added_bos, bos
if isinstance(c, str) or isinstance(c, RawTextItem):
if isinstance(c, RawTextItem):
c = c.text
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
added_bos = True
elif isinstance(c, RawMediaItem):
if not self.vision_args:
raise ValueError("The model is not vision-enabled, but a media item was found")
bos = False if added_bos else bos
if bos:
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
added_bos = True
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = convert_rgba_to_rgb(image)
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
if image_tiles.shape[0] > 1:
image_global = self.image_transform(image)
image_global = image_global.unsqueeze(0)
image_combine = torch.cat((image_tiles, image_global), dim=0)
image_tiles = image_combine
transformed_image = TransformedImage(image_tiles=image_tiles, aspect_ratio=ar)
tokens.extend(self._encode_image(transformed_image))
tranformed_images.append(transformed_image)
if isinstance(content, list):
for c in content:
_process(c)
else:
_process(content)
return tokens, tranformed_images
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[TransformedImage]]:
tokens = self._encode_header(message.role)
images = []
def _process_content(c):
toks, imgs = self._encode_content(c)
tokens.extend(toks)
images.extend(imgs)
if message.role == "assistant" and len(message.tool_calls) > 0:
tokens.append(self.tokenizer.special_tokens["<|python_start|>"])
_process_content(message.content)
if message.role == "assistant" and len(message.tool_calls) > 0:
tokens.append(self.tokenizer.special_tokens["<|python_end|>"])
if message.role == "user" and message.context is not None:
# This is RAG context; why is it here in the chat format? I don't think
# this is needed and can be moved upwards
_process_content("\n\n")
_process_content(message.context)
if message.role == "assistant":
for t in message.tool_calls:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
return tokens, images
def encode_dialog_prompt(
self,
messages: List[RawMessage],
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> LLMInput:
tokens = []
images = []
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
for message in messages:
toks, imgs = self.encode_message(message, tool_prompt_format)
tokens.extend(toks)
images.extend(imgs)
# Add the start of an assistant message for the model to complete.
tokens.extend(self._encode_header("assistant"))
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
content = content.strip(" ")
header_str = self.possible_headers[Role.assistant]
if content.startswith(header_str):
content = content[len(header_str) :]
ipython = content.startswith("<|python_start|>")
if ipython:
content = content[len("<|python_start|>") :]
content = content.replace("<|python_end|>", "")
if content.endswith("<|eot|>"):
content = content[: -len("<|eot|>")]
stop_reason = StopReason.end_of_turn
elif content.endswith("<|eom|>"):
content = content[: -len("<|eom|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
# Sometimes when agent has custom tools alongside builin tools
# Agent responds for builtin tool calls in the format of the custom tools
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
tool_name, query = builtin_tool_info
tool_arguments = {
"query": query,
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
)
)
content = ""
return RawMessage(
role="assistant",
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
return LLMInput(
tokens=tokens,
images=[x.image_tiles for x in images] if len(images) > 0 else None,
)

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from io import BytesIO
from pathlib import Path
from typing import List
from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
from llama_stack.models.llama.prompt_format import (
Llama4UseCase,
TextCompletionContent,
UseCase,
)
THIS_DIR = Path(__file__).parent
def usecases(base_model: bool = False) -> List[UseCase | str]:
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
img_small_dog = f.read()
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:
img_dog = f.read()
with open(THIS_DIR.parent / "resources/pasta.jpeg", "rb") as f:
img_pasta = f.read()
out = []
out.extend(
[
textwrap.dedent(
"""
# Llama 4 - Prompt Formats
## Tokens
Here is a list of special tokens that are supported by Llama 4:
- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|header_start|>` and `<|header_end|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user and assistant].
- `<|eot|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
- `<|image_start|>` and `<|image_end|>`: These tokens enclose the image data in the prompt.
- `<|patch|>`: This token represents a piece of the tile/
- `<|tile_y_separator|>` and `<|tile_x_separator|>`: These tokens are used to separate the y and x tiles of an image
- `<|image|>`: In the new architecture, this token now separates the regular sized image information from a downsized version of it that fits in a single tile. The longer side is used for calculating the scale factor and the rest is padded to fit the tile.
"""
),
textwrap.dedent(
"""
There are 3 different roles that are supported by Llama 4
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
"""
),
]
)
if base_model:
out.extend(
[
"# Llama 4 Base Model",
Llama4UseCase(
title="Text completion - Paris information",
description="Text completion for Llama 4 base model uses this format.",
dialogs=[TextCompletionContent(content="The capital of France is Paris")],
),
Llama4UseCase(
title="Text completion - The color of the sky",
description="Text completion for Llama 4 base model uses this format.",
dialogs=[
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be")
],
notes="",
),
Llama4UseCase(
title="Text completion - Translation example",
description="Text completion for Llama 4 base model uses this format.",
dialogs=[
TextCompletionContent(
content="""apple is pomme,
bannana is banane,
cherry is"""
)
],
notes="",
),
]
)
out.extend(
[
"# Llama 4 Instruct Model",
Llama4UseCase(
title="Simple User and assistant conversation",
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
dialogs=[
[
RawMessage(role="system", content="You are a helpful assistant"),
RawMessage(
role="user",
content="Answer who are you in the form of jeopardy?",
),
]
],
notes="",
max_gen_len=512,
),
"# Image prompt format",
Llama4UseCase(
title="Single image prompt format - small image",
description="This example passes an image that is smaller than the tile size, to show the tile separator tokens are not needed",
dialogs=[
[
RawMessage(
role="user",
content=[
RawMediaItem(data=BytesIO(img_small_dog)),
RawTextItem(text="Describe this image in two sentences"),
],
)
]
],
notes="""Notice the structure of the image section:
```
<|image_start|><|image|><|patch|>...<|patch|><|image_end|>
```
This is due to the image being smaller than the tile size.
""",
max_gen_len=512,
),
Llama4UseCase(
title="Single image prompt format",
description="Here is an example of how to pass an image to the model",
dialogs=[
[
RawMessage(
role="user",
content=[
RawMediaItem(data=BytesIO(img_dog)),
RawTextItem(text="Describe this image in two sentences"),
],
)
]
],
notes="""With a bigger image, the image will include the tile separator tokens. Additionally, the image tag now separates a scaled down version of the image from the regular sized image.
```
<|image_start|><|patch|>...<|patch|><|tile_x_separator|><|patch|>...<|patch|><|tile_y_separator|><|patch|>...<|patch|><|image|><|patch|>...<|patch|><|image_end|>
```
""",
max_gen_len=1024,
),
Llama4UseCase(
title="Multiple images prompt format",
description="Here is an example of how to pass an image to the model",
dialogs=[
[
RawMessage(
role="user",
content=[
RawMediaItem(data=BytesIO(img_dog)),
RawMediaItem(data=BytesIO(img_pasta)),
RawTextItem(text="Describe these images in two sentences"),
],
)
]
],
notes="With multiple images, each one is encapsulated in their corresponding image tags.",
max_gen_len=4096,
),
"# Tool calling\nWe are continuing the format for zero shot function calling used in previous versions of Llama. All available functions can be provided either in the system message or in the user message.",
Llama4UseCase(
title="Zero shot function calling - system message",
dialogs=[
[
RawMessage(
role="system",
content="""You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": [
"city"
],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
""",
),
RawMessage(
role="user",
content="What is the weather in SF and Seattle?",
),
]
],
notes=textwrap.dedent(
"""
- The output supports multiple, and parallel tool calls natively
- JSON format for defining the functions in the system prompt is similar to Llama3.1
"""
),
),
Llama4UseCase(
title="Zero shot function calling - user message",
description=textwrap.dedent(
"""
Similar to the above example, you can also provide information for all the available tools in the user message.
"""
),
dialogs=[
[
RawMessage(
role="user",
content="""Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
{
"name": "get_user_info",
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
"parameters": {
"type": "dict",
"required": [
"user_id"
],
"properties": {
"user_id": {
"type": "integer",
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
},
"special": {
"type": "string",
"description": "Any special information or parameters that need to be considered while fetching user details.",
"default": "none"
}
}
}
}
]
Should you decide to return the function call(s), put them in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
You SHOULD NOT include any other text in the response.""",
),
]
],
notes=textwrap.dedent(
"""
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
"""
),
),
Llama4UseCase(
title="Tool calling with custom formats",
description=textwrap.dedent(
"""
Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.
"""
),
dialogs=[
[
RawMessage(
role="user",
content="""You have access to the following functions:\nUse the function 'trending_songs' to 'Returns the trending songs on a Music site':\n{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}\n\nThink very carefully before calling functions.\nIf you choose to call a function ONLY reply in the following format with no prefix or suffix:\n\n<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|>""",
),
RawMessage(
role="user",
content="Use tools to get latest trending songs",
),
]
],
),
]
)
return out

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,255 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000
_INSTANCE = None
def get_reserved_special_tokens(name, count, start_index=0):
return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]
# 200005, ..., 200079
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|header_start|>",
"<|header_end|>",
"<|eom|>",
"<|eot|>",
"<|step|>",
"<|text_post_train_reserved_special_token_0|>",
"<|text_post_train_reserved_special_token_1|>",
"<|text_post_train_reserved_special_token_2|>",
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
"<|python_start|>",
"<|python_end|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 6
) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
# 200080, ..., 201133
LLAMA4_VISION_SPECIAL_TOKENS = [
"<|image_start|>",
"<|image_end|>",
"<|vision_reserved_special_token_0|>",
"<|vision_reserved_special_token_1|>",
"<|tile_x_separator|>",
"<|tile_y_separator|>",
"<|vision_reserved_special_token_2|>",
"<|vision_reserved_special_token_3|>",
"<|vision_reserved_special_token_4|>",
"<|vision_reserved_special_token_5|>",
"<|image|>",
"<|vision_reserved_special_token_6|>",
"<|patch|>",
] + get_reserved_special_tokens(
"vision", 1041, 7
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
BASIC_SPECIAL_TOKENS = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|fim_prefix|>",
"<|fim_middle|>",
"<|fim_suffix|>",
]
class Tokenizer:
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
num_reserved_special_tokens = 2048
O200K_PATTERN = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa: E501
@classmethod
def get_instance(cls):
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
return _INSTANCE
def __init__(self, model_path: str):
"""
Initializes the Tokenizer with a Tiktoken model.
Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path
mergeable_ranks = load_tiktoken_bpe(model_path)
num_base_tokens = len(mergeable_ranks)
special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
assert len(set(special_tokens)) == len(special_tokens)
assert len(special_tokens) <= self.num_reserved_special_tokens
reserved_tokens = [
f"<|reserved_special_token_{i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
]
special_tokens = special_tokens + reserved_tokens
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.O200K_PATTERN,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
self.n_words: int = num_base_tokens + len(special_tokens)
# BOS / EOS token IDs
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
self.pad_id: int = self.special_tokens["<|finetune_right_pad|>"]
self.eot_id: int = self.special_tokens["<|eot|>"]
self.eom_id: int = self.special_tokens["<|eom|>"]
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom|>"],
self.special_tokens["<|eot|>"],
]
def encode(
self,
s: str,
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
"""
Encodes a string into a list of token IDs.
Args:
s (str): The input string to be encoded.
bos (bool): Whether to prepend the beginning-of-sequence token.
eos (bool): Whether to append the end-of-sequence token.
allowed_special ("all"|set[str]): allowed special tokens in string
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
Returns:
list[int]: A list of token IDs.
By default, setting disallowed_special=() encodes a string by ignoring
special tokens. Specifically:
- Setting `disallowed_special` to () will cause all text corresponding
to special tokens to be encoded as natural text (insteading of raising
an error).
- Setting `allowed_special` to "all" will treat all text corresponding
to special tokens to be encoded as special tokens.
"""
if allowed_special is None:
allowed_special = set()
assert type(s) is str
substrs = (
substr
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
for substr in self._split_whitespaces_or_nonwhitespaces(
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
for substr in substrs:
t.extend(
self.model.encode(
substr,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
)
if bos:
t.insert(0, self.bos_id)
if eos:
t.append(self.eos_id)
return t
def decode(self, t: Sequence[int]) -> str:
"""
Decodes a list of token IDs into a string.
Args:
t (List[int]): The list of token IDs to be decoded.
Returns:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
"""
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
consecutive whitespaces or consecutive non-whitespaces.
"""
current_slice_len = 0
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
slice_start = 0
for i in range(len(s)):
is_now_space = s[i].isspace()
if current_slice_is_space ^ is_now_space:
current_slice_len = 1
current_slice_is_space = is_now_space
else:
current_slice_len += 1
if current_slice_len > max_consecutive_slice_len:
yield s[slice_start:i]
slice_start = i
current_slice_len = 1
yield s[slice_start:]

View file

@ -27,6 +27,10 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from .llama3.interface import LLama31Interface
from .llama3.template_data import (
@ -46,6 +50,7 @@ class UseCase(BaseModel):
dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
notes: str = ""
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
max_gen_len: int = 512
def md_format(self):
section = textwrap.dedent(
@ -75,17 +80,16 @@ class UseCase(BaseModel):
elif isinstance(dialog, TextCompletionContent):
input_tokens, output_tokens = generator.text_completion_raw(
dialog.content,
max_gen_len=64,
temperature=0.1,
top_p=0.95,
max_gen_len=64,
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
max_gen_len=512,
temperature=0.0,
top_p=0.95,
tool_prompt_format=self.tool_prompt_format,
max_gen_len=self.max_gen_len,
)
text += "##### Input Prompt Format\n"
@ -115,6 +119,45 @@ class UseCase(BaseModel):
return section
class Llama4UseCase(UseCase):
def dialogs_to_text(self, generator) -> str:
def _code_block(text):
return f"```\n{text}\n```"
text = ""
tokenizer = Tokenizer.get_instance()
temperature = 0.0
for dialog in self.dialogs:
if isinstance(dialog, str):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
# TODO pass the raw input and do the encoding in the text completion function
input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
llm_input = LLMInput(tokens=input_tokens)
output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
llm_input, temperature=temperature, max_gen_len=self.max_gen_len
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
temperature=temperature,
max_gen_len=self.max_gen_len,
)
text += "##### Input Prompt Format\n"
text += _code_block(tokenizer.decode(input_tokens))
text += "\n\n"
text += "##### Model Response Format\n"
text += _code_block(tokenizer.decode(output_tokens))
text += "\n\n"
return text
def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
interface = LLama31Interface(tool_prompt_format)

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 438 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

View file

@ -19,6 +19,7 @@ from .datatypes import (
CheckpointQuantizationFormat,
CoreModelId,
Model,
ModelFamily,
SamplingParams,
TopPSamplingStrategy,
)
@ -36,7 +37,13 @@ def resolve_model(descriptor: str) -> Optional[Model]:
def all_registered_models() -> List[Model]:
return (
llama2_family() + llama3_family() + llama3_1_family() + llama3_2_family() + llama3_3_family() + safety_models()
llama2_family()
+ llama3_family()
+ llama3_1_family()
+ llama3_2_family()
+ llama3_3_family()
+ llama4_family()
+ safety_models()
)
@ -83,6 +90,60 @@ def llama3_3_family() -> List[Model]:
]
def llama4_family() -> List[Model]:
return [
*llama4_base_models(),
*llama4_instruct_models(),
]
def llama4_base_models() -> List[Model]:
return [
Model(
core_model_id=CoreModelId.llama4_scout_17b_16e,
description="Llama 4 Scout (17b 16 experts model)",
huggingface_repo="meta-llama/Llama-4-Scout-17B-16E",
pth_file_count=8,
arch_args={},
),
Model(
core_model_id=CoreModelId.llama4_maverick_17b_128e,
description="Llama 4 Maverick (17b 128 experts model)",
huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E",
pth_file_count=8,
arch_args={},
),
]
def llama4_instruct_models() -> List[Model]:
return [
Model(
core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,
description="Llama 4 Scout (17b 16 experts instruct model)",
huggingface_repo="meta-llama/Llama-4-Scout-17B-16E-Instruct",
pth_file_count=8,
arch_args={},
),
Model(
core_model_id=CoreModelId.llama4_maverick_17b_128e_instruct,
description="Llama 4 Maverick (17b 128 experts instruct model)",
huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
pth_file_count=8,
arch_args={},
),
Model(
core_model_id=CoreModelId.llama4_maverick_17b_128e_instruct,
description="Llama 4 Maverick (FP8 quantized)",
huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
pth_file_count=8,
variant="fp8",
arch_args={},
),
]
def llama2_base_models() -> List[Model]:
return [
Model(
@ -989,12 +1050,24 @@ def llama_meta_pth_size(model: Model) -> int:
if model.core_model_id not in (
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_405b_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
):
return 0
if model.pth_file_count == 16:
return 51268302389
elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
return 60903742309
else:
return 101470976045
if model.model_family == ModelFamily.llama3_1:
if model.pth_file_count == 16:
return 51268302389
elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
return 60903742309
else:
return 101470976045
if model.model_family == ModelFamily.llama4:
if model.core_model_id == CoreModelId.llama4_maverick_17b_128e:
return 100458118386
elif model.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
return 54121549657
else:
return 100426653046

View file

@ -255,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
input_messages = last_turn_messages
input_messages = last_turn.input_messages
turn_id = request.turn_id
start_time = last_turn.started_at

View file

@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
from typing import Generator, List, Optional, Tuple
import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
JsonSchemaResponseFormat,
ResponseFormat,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
Model,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
get_default_tool_prompt_format,
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .inference import resolve_model
from .llama3.generation import Llama3
from .llama4.generation import Llama4
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: Optional[ResponseFormat],
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
if not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(f"Unsupported response format type {response_format.type}")
parser = JsonSchemaParser(response_format.json_schema)
data = TokenEnforcerTokenizerData(
_build_regular_tokens_list(tokenizer, vocab_size),
tokenizer.decode,
tokenizer.stop_tokens,
)
token_enforcer = TokenEnforcer(data, parser)
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
special_token_ids = set(tokenizer.special_tokens.values())
for token_idx in range(vocab_size):
if token_idx in special_token_ids:
continue
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
decoded_regular = tokenizer.decode([token_idx])
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
return regular_tokens
def _infer_sampling_params(sampling_params: SamplingParams):
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
temperature = 0.0
top_p = 1.0
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
temperature = sampling_params.strategy.temperature or 1.0
top_p = sampling_params.strategy.top_p or 1.0
else:
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
return temperature, top_p
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
tool_config = request.tool_config
if tool_config is not None and tool_config.tool_prompt_format is not None:
return tool_config.tool_prompt_format
else:
return get_default_tool_prompt_format(request.model)
class Llama4Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
model_id: str,
llama_model: Model,
):
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
quantization_mode = "fp8_mixed"
elif isinstance(config.quantization, Int4QuantizationConfig):
quantization_mode = "int4_mixed"
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
quantization_mode = None
self.inner_generator = Llama4.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_input=self.formatter.encode_content(request.content),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
class Llama3Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
model_id: str,
llama_model: Model,
):
self.inner_generator = Llama3.build(
config=config,
model_id=model_id,
llama_model=llama_model,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
model_input=self.formatter.encode_content(request.content),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)

View file

@ -34,11 +34,16 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import (
ModelFamily,
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@ -55,7 +60,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import MetaReferenceInferenceConfig
from .llama3.generation import Llama3
from .generators import Llama3Generator, Llama4Generator
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
@ -64,6 +69,14 @@ log = logging.getLogger(__name__)
SEMAPHORE = asyncio.Semaphore(1)
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
return Llama3Generator(config, model_id, llama_model)
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
return Llama4Generator(config, model_id, llama_model)
class MetaReferenceInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
@ -77,29 +90,10 @@ class MetaReferenceInferenceImpl(
async def initialize(self) -> None:
pass
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
self.generator.start()
else:
self.generator = Llama3.build(self.config, model_id, llama_model)
self.model_id = model_id
self.llama_model = llama_model
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def unregister_model(self, model_id: str) -> None:
pass
@ -127,11 +121,57 @@ class MetaReferenceInferenceImpl(
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
# TODO: what is this?! you can't really specify skipping via model metadata
# kill this madness
if "skip_load" in model.metadata and model.metadata["skip_load"]:
return model
await self.load_model(model.identifier, llama_model)
return model
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if llama_model.model_family in {
ModelFamily.llama3,
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
}:
builder_fn = llama3_builder_fn
elif llama_model.model_family == ModelFamily.llama4:
builder_fn = llama4_builder_fn
else:
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=llama_model.pth_file_count,
builder_fn=builder_fn,
builder_params=builder_params,
formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance())
if llama_model.model_family == ModelFamily.llama4
else Llama3ChatFormat(Llama3Tokenizer.get_instance())
),
)
self.generator.start()
else:
self.generator = builder_fn(*builder_params)
self.model_id = model_id
self.llama_model = llama_model
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def completion(
self,
model_id: str,
@ -164,14 +204,16 @@ class MetaReferenceInferenceImpl(
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
stop_reason = None
for token_result in self.generator.completion(request):
if token_result.text == "<|eot_id|>":
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
@ -205,6 +247,8 @@ class MetaReferenceInferenceImpl(
yield x
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
tokenizer = self.generator.formatter.tokenizer
def impl():
tokens = []
logprobs = []
@ -212,9 +256,9 @@ class MetaReferenceInferenceImpl(
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
if request.logprobs:
@ -225,11 +269,9 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
tokens = tokens[:-1]
content = self.generator.formatter.tokenizer.decode(tokens)
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
elif content.endswith("<|eom_id|>"):
content = content[: -len("<|eom_id|>")]
return CompletionResponse(
content=content,
stop_reason=stop_reason,
@ -288,6 +330,8 @@ class MetaReferenceInferenceImpl(
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
tokenizer = self.generator.formatter.tokenizer
def impl():
tokens = []
logprobs = []
@ -296,9 +340,9 @@ class MetaReferenceInferenceImpl(
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
if request.logprobs:
@ -326,6 +370,8 @@ class MetaReferenceInferenceImpl(
return impl()
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -355,10 +401,10 @@ class MetaReferenceInferenceImpl(
)
continue
if token_result.text == "<|eot_id|>":
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:

View file

@ -4,17 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import math
import os
import sys
import time
from pathlib import Path
from typing import Generator, List, Optional, Tuple, Union
from typing import Callable, Generator, Optional, Union
import torch
import torch.nn.functional as F
@ -23,27 +19,16 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
ResponseFormat,
ResponseFormatType,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
Model,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from ..common import TokenResult, model_checkpoint_dir
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
@ -51,7 +36,7 @@ from .args import ModelArgs
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
log = logging.getLogger(__name__)
log = get_logger(__name__, category="inference")
class Llama3:
@ -146,7 +131,7 @@ class Llama3:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
from ..quantization.loader import convert_to_fp8_quantized_model
from .quantization.loader import convert_to_fp8_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
@ -159,7 +144,7 @@ class Llama3:
model.load_state_dict(state_dict, strict=False)
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
elif isinstance(config.quantization, Int4QuantizationConfig):
from ..quantization.loader import convert_to_int4_quantized_model
from .quantization.loader import convert_to_int4_quantized_model
model = Transformer(model_args)
model = convert_to_int4_quantized_model(model, model_args, config)
@ -169,7 +154,7 @@ class Llama3:
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
from ..quantization.hadamard_utils import (
from ..hadamard_utils import (
add_hadamard_transform_for_spinquant,
)
@ -222,9 +207,8 @@ class Llama3:
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
print_input_tokens: bool = False,
logits_processor: Optional["LogitsProcessor"] = None,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
params = self.model.params
@ -292,7 +276,7 @@ class Llama3:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@ -336,58 +320,6 @@ class Llama3:
if all(eos_reached):
break
def completion(
self,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(request.content)
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
request.messages,
request.tool_config.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
def sample_top_p(probs, p):
"""
@ -412,72 +344,3 @@ def sample_top_p(probs, p):
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: Optional[ResponseFormat],
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
if response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Unsupported response format type {response_format.type}")
parser = JsonSchemaParser(response_format.json_schema)
data = TokenEnforcerTokenizerData(
_build_regular_tokens_list(tokenizer, vocab_size),
tokenizer.decode,
tokenizer.stop_tokens,
)
token_enforcer = TokenEnforcer(data, parser)
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
special_token_ids = set(tokenizer.special_tokens.values())
for token_idx in range(vocab_size):
if token_idx in special_token_ids:
continue
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
decoded_regular = tokenizer.decode([token_idx])
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
return regular_tokens
def _infer_sampling_params(sampling_params: SamplingParams):
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
temperature = 0.0
top_p = 1.0
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
temperature = sampling_params.strategy.temperature
top_p = sampling_params.strategy.top_p
else:
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
return temperature, top_p

View file

@ -7,9 +7,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import logging
# type: ignore
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@ -19,22 +19,27 @@ from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.inference.meta_reference.quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
from ...llama3.args import ModelArgs
from ...llama3.model import Transformer, TransformerBlock
from ..config import MetaReferenceQuantizedInferenceConfig
from ...config import MetaReferenceQuantizedInferenceConfig
from ..args import ModelArgs
from ..model import Transformer, TransformerBlock
log = logging.getLogger(__name__)
log = get_logger(__name__, category="quantization")
def swiglu_wrapper(
self,
x: Tensor,
):
from .fp8_impls import ffn_swiglu
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
@ -51,8 +56,7 @@ def convert_to_fp8_quantized_model(
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
assert config.model is not None, "Model must be specified for quantized inference"
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
@ -82,7 +86,7 @@ def convert_to_fp8_quantized_model(
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward) # type: ignore
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
@ -136,6 +140,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision=precision,
scales_precision=scales_precision,
)
self.lora_scale: Optional[float] = None
self.adaptor: Optional[nn.Sequential] = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@ -143,9 +149,6 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
self.lora_scale = lora_scale
else:
self.adaptor = None
self.lora_scale = None
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
@ -293,10 +296,10 @@ def convert_to_int4_quantized_model(
) -> Transformer:
"""Convert the model to int4 quantized model."""
if model_args.quantization_args is None:
raise ValueError("'quantization_args' cannot be None. Please specify it.")
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
@ -317,4 +320,4 @@ def convert_to_int4_quantized_model(
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)
return cast(Transformer, model.to(device))

View file

@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel, model_validator
class QuantizationScheme(Enum):
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
class QuantizationArgs(BaseModel):
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
spinquant: bool = False
class LoRAArgs(BaseModel):
rank: int
scale: float
class MoEArgs(BaseModel):
num_experts: int = -1
capacity_factor: float = 1.0 # capacity factor determines how many tokens each expert can choose
auto_scale_F: bool = ( # noqa: N815
True # if true, rescales hidden_dim such that number of activated params is same as equivalent dense layer
)
top_k: int = 1
interleave_moe_layer_step: int = 1
class Size(BaseModel):
height: int
width: int
class VisionArgs(BaseModel):
image_size: Size
patch_size: Size
# parameters for the encoder transformer
dim: int
n_layers: int
n_heads: int
mlp_ratio: float
output_dim: int
pixel_shuffle_ratio: float
class ModelArgs(BaseModel):
dim: int = -1
n_layers: int = -1
n_heads: int = -1
n_kv_heads: Optional[int] = None
head_dim: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_exp: Optional[float] = None
norm_eps: float = 1e-5
attention_chunk_size: Optional[int] = None
rope_theta: float = 500000
use_scaled_rope: bool = False
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
attn_temperature_tuning: bool = False
floor_scale: float = 8192.0
attn_scale: float = 0.1
vision_args: Optional[VisionArgs] = None
moe_args: Optional[MoEArgs] = None
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
max_batch_size: int = 32
max_seq_len: int = 2048
@model_validator(mode="after")
def validate(self) -> "ModelArgs":
assert self.n_kv_heads <= self.n_heads, f"n_kv_heads ({self.n_kv_heads}) must be <= n_heads ({self.n_heads})"
assert self.n_heads % self.n_kv_heads == 0, (
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
)
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
return self

View file

@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
@dataclass
class MaskedEmbedding:
embedding: torch.Tensor
mask: torch.Tensor
@dataclass
class LLMInput:
"""
This is the input to the LLM from the "user" -- the user in this case views the
Llama4 model holistically and does not care or know about its inner workings (e.g.,
whether it has an encoder or if it is early fusion or not.)
This is distinct from the "TransformerInput" class which is really the Llama4
backbone operating on early fused modalities and producing text output
"""
tokens: torch.Tensor
# images are already pre-processed (resized, tiled, etc.)
images: Optional[List[torch.Tensor]] = None
@dataclass
class TransformerInput:
"""
This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
are expected to be "embedded" via encoders sitting before this layer in the model.
"""
tokens: torch.Tensor
# tokens_position defines the position of the tokens in each batch,
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
# - when it is an int, the start position are the same for all batches
tokens_position: Union[torch.Tensor, int]
image_embedding: Optional[MaskedEmbedding] = None
@dataclass
class LLMOutput:
logits: torch.Tensor
TransformerOutput = LLMOutput

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from typing import Any, Dict, List
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import nn
from torch.nn import functional as F
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
do_reduce: bool = True,
):
super().__init__()
self.do_reduce = do_reduce
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
state_dict[prefix + "w1.weight"] = w1
state_dict[prefix + "w3.weight"] = w3
state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")
def forward(self, x):
x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
out = F.linear(x, self.w2.weight)
if self.do_reduce:
return reduce_from_model_parallel_region(out)
return out

View file

@ -0,0 +1,330 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import codecs
import io
import json
import os
import sys
import time
from enum import Enum
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from llama_stack.models.llama.llama4.chat_format import (
ChatFormat,
RawContent,
RawMessage,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from ..common import TokenResult
from .args import ModelArgs
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
class Llama4:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[str] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
**params,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
)
tokenizer = Tokenizer.get_instance()
# TODO: params.json should always have correct vocab_size
if model_args.vocab_size == -1:
model_args.vocab_size = tokenizer.n_words
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
print("Model args:\n", model_args.model_dump_json(indent=2))
ckpt_path = checkpoints[get_model_parallel_rank()]
print(f"Loading checkpoint from {ckpt_dir}...")
with open(ckpt_path, "rb") as f:
checkpoint = torch.load(f, map_location="cpu", weights_only=True)
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama4(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer, vision_args=args.vision_args)
@torch.inference_mode()
def generate(
self,
llm_input: LLMInput,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
params = self.model.args
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input and get_model_parallel_rank() == 0:
tokens_to_print = list(llm_input.tokens)
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [llm_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(llm_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
image_embedding = None
if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
image_mask = image_mask.unsqueeze(-1)
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
image_batch = [llm_input.images]
image_embedding = MaskedEmbedding(
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
mask=image_mask,
)
xformer_input = TransformerInput(
tokens=tokens[:, prev_pos:cur_pos],
tokens_position=prev_pos,
image_embedding=image_embedding,
)
xformer_output = self.model.forward(xformer_input)
logits = xformer_output.logits
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
prev_pos = cur_pos
if all(eos_reached):
break
def completion(
self,
content: RawContent,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
llm_input = self.formatter.encode_content(content)
for result in self.generate(
llm_input=llm_input,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
def chat_completion(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
llm_input = self.formatter.encode_dialog_prompt(messages)
for result in self.generate(
llm_input=llm_input,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
):
llm_input = self.formatter.encode_dialog_prompt(messages)
output_tokens = []
for result in self.generate(
llm_input=llm_input,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
):
output_tokens.append(result.token)
return llm_input.tokens, output_tokens
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -0,0 +1,453 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Any, Dict, List, Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from torch import nn
from .args import ModelArgs
from .datatypes import TransformerInput, TransformerOutput
from .ffn import FeedForward
from .moe import MoE
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class L2Norm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
# TODO: this module needs to be moved into a separate file since it can be used by
# the vision encoder as well.
def __init__(
self,
args: ModelArgs,
use_qk_norm: bool,
use_rope: bool,
add_bias: bool = False,
):
super().__init__()
self.use_rope = use_rope
self.use_qk_norm = use_qk_norm
# For attention temperature tuning
self.attn_temperature_tuning = args.attn_temperature_tuning
self.floor_scale = args.floor_scale
self.attn_scale = args.attn_scale
self.n_heads = args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=add_bias,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.qk_norm = None
if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
if r != 0:
raise ValueError(
f"shape={tuple(wqkv.shape)} is not divisible by "
f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
)
wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
state_dict[prefix + "wq.weight"] = wq
state_dict[prefix + "wk.weight"] = wk
state_dict[prefix + "wv.weight"] = wv
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
if self.use_rope:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
if self.use_qk_norm:
xq = self.qk_norm(xq)
xk = self.qk_norm(xk)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
# the inference-time temperature tuning function is customized to not affect short context
# while working at very long context
if self.attn_temperature_tuning and not self.use_rope:
seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
# reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
attn_scales = attn_scales.view(1, seqlen, 1, 1)
xq = xq * attn_scales
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
xk = self.cache_k[:bsz, : start_pos + seqlen]
xv = self.cache_v[:bsz, : start_pos + seqlen]
xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
xk = xk.repeat_interleave(self.n_rep, dim=1)
xv = xv.repeat_interleave(self.n_rep, dim=1)
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
output = self.wo(attn_output)
return output
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
use_rope = not self.is_nope_layer
use_qk_norm = args.use_qk_norm and not self.is_nope_layer
self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
self.feed_forward = MoE(
dim=args.dim,
hidden_dim=int(args.ffn_exp * args.dim),
ffn_dim_multiplier=args.ffn_dim_multiplier,
multiple_of=args.multiple_of,
moe_args=args.moe_args,
)
else:
hidden_dim = int(4 * args.dim)
hidden_dim = int(2 * hidden_dim / 3)
if args.ffn_dim_multiplier is not None:
hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=hidden_dim,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
elif prefix + "feed_forward.norm.weight" in state_dict:
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
for k in (
"feed_forward.experts.mlp",
"feed_forward.mlp_shared",
"attention.wo",
"attention.wqkv",
):
if prefix + k + "._extra_state" in state_dict:
state_dict.pop(prefix + k + "._extra_state")
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
global_attn_mask: Optional[torch.Tensor],
local_attn_mask: Optional[torch.Tensor],
):
# The iRoPE architecture uses global attention mask for NoPE layers or
# if chunked local attention is not used
if self.is_nope_layer or local_attn_mask is None:
mask = global_attn_mask
else:
mask = local_attn_mask
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, args: ModelArgs, **kwargs) -> None:
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(TransformerBlock(layer_id, args))
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
self.freqs_cis = precompute_freqs_cis(
args.dim // args.n_heads,
args.max_seq_len * 2,
args.rope_theta,
args.use_scaled_rope,
)
vision_args = self.args.vision_args
if vision_args:
# circular import otherwise until we refactor out Attention
from .vision.embedding import VisionEmbeddings
self.vision_embeddings = VisionEmbeddings(vision_args)
self.vision_projection = ColumnParallelLinear(
vision_args.output_dim,
args.dim,
bias=False,
init_method=lambda x: x,
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "rope.freqs" in state_dict:
state_dict.pop(prefix + "rope.freqs")
@torch.inference_mode()
def forward(self, model_input: TransformerInput) -> TransformerOutput:
tokens = model_input.tokens
start_pos = model_input.tokens_position
assert isinstance(start_pos, int), (
"This implementation does not support different start positions per batch item"
)
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
if image_embedding := model_input.image_embedding:
h_image = self.vision_projection(image_embedding.embedding)
h = h * ~image_embedding.mask + h_image * image_embedding.mask
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
global_attn_mask, local_attn_mask = None, None
if seqlen > 1:
global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
# https://github.com/pytorch/pytorch/issues/100005
# torch.triu is buggy when the device is mps: filled values are
# nan instead of 0.
if global_attn_mask.device.type == torch.device("mps").type:
global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
if chunk_size := self.args.attention_chunk_size:
local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
h = self.norm(h)
output = self.output(h).float()
return TransformerOutput(logits=output)
# tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
# in the iRoPE architecture
def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
block_pos = torch.abs(
(torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
- (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
)
token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
return mask.to(device)

View file

@ -0,0 +1,224 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N806
# pyre-strict
from typing import Any, Dict, List
import fairscale.nn.model_parallel.initialize as fs_init
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torch.nn import functional as F
from .args import MoEArgs
from .ffn import FeedForward
class Experts(nn.Module):
def __init__(
self,
num_local_experts: int,
dim: int,
hidden_dim: int,
) -> None:
super().__init__()
dtype = torch.get_default_dtype()
self.num_local_experts = num_local_experts
self.dim = dim
divide_factor = fs_init.get_model_parallel_world_size()
self.w1: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
)
self.w2: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
divide_exact(hidden_dim, divide_factor),
dim,
dtype=dtype,
)
)
self.w3: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
self.prefix = prefix
if prefix + "moe_w_in_eD_F" in state_dict:
e = self.num_local_experts
D = self.dim
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
def forward(
self,
routed_in_egD: torch.Tensor, # noqa: N803
) -> torch.Tensor:
e = self.num_local_experts
D = self.dim
x_egD = routed_in_egD.view(e, -1, D)
out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
out_egD = out_egD.view(-1, D)
return out_egD
def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
return torch.bmm(middle_out_egF, w2)
class MoE(torch.nn.Module):
"""
This EC implementation is modified from the original EC module.
We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
This module supports 3 sharding methods of the experts:
- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
Several commonly used annotations include:
- a: bsz*slen
- E: number of experts
- e: number of local experts per ep (n_experts/ep)
- et: number of local experts per tp (n_experts/tp)
- D: hidden dimension
- d: D/tp
- F: model dimension
- f: F/tp (used in column/row-parallel linear)
- G: number of tokens per expert (a * capacity_factor / E)
- g: number of tokens per expert per TP rank (i.e., G/TP)
- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
Examples:
x_aD [a, D]
routed_in_etG_D [et*G, D]
x_eGGD: [e, GG, D]
"""
def __init__(
self,
dim: int,
hidden_dim: int,
ffn_dim_multiplier: float,
multiple_of: int,
moe_args: MoEArgs,
) -> None:
super().__init__()
self.moe_args = moe_args
hidden_dim_denom: float = 1
if moe_args.auto_scale_F:
hidden_dim_denom = moe_args.capacity_factor + 1
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
if moe_args.auto_scale_F:
hidden_dim = int(hidden_dim / hidden_dim_denom)
hidden_dim += -hidden_dim % multiple_of
num_local_experts: int = moe_args.num_experts
dtype: torch.dtype = torch.get_default_dtype()
self.experts = Experts(
num_local_experts,
dim,
hidden_dim,
)
self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "w_in_shared_FD.weight" in state_dict:
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
_, slen, D = x_bsD.shape
x_aD = x_bsD.view(-1, D)
a = x_aD.shape[0]
router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
router_scores = (
torch.full_like(router_scores.transpose(0, 1), float("-inf"))
.scatter_(1, router_indices_aK, router_scores_aK)
.transpose(0, 1)
)
router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
router_scores = torch.sigmoid(router_scores)
routed_in_EG_D: Tensor = torch.gather(
x_aD,
dim=0,
index=router_indices.reshape(-1, 1).expand(-1, D),
)
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
out_aD = self.shared_expert(x_aD)
routed_out_egg_D = self.experts(routed_in_EG_D.detach())
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
out_aD.scatter_add_(
dim=0,
index=router_indices_EG_D,
src=routed_out_egg_D.view(-1, D),
)
out_aD = reduce_from_model_parallel_region(out_aD)
return out_aD.view(-1, slen, D)
def divide_exact(numerator: int, denominator: int) -> int:
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
return numerator // denominator

View file

@ -0,0 +1,436 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from collections import defaultdict
from typing import Optional, Set, Tuple
import torch
import torchvision.transforms as tv
from PIL import Image, ImageFile
from torchvision.transforms import functional as F
ImageFile.LOAD_TRUNCATED_IMAGES = True
IMAGE_RES = 448
class ResizeNormalizeImageTransform:
def __init__(
self,
size_width=None,
size_height=None,
) -> None:
self._size_width = size_width or IMAGE_RES
self._size_height = size_height or IMAGE_RES
self._mean = (0.5, 0.5, 0.5)
self._std = (0.5, 0.5, 0.5)
self.tv_transform = tv.Compose(
[
tv.Resize((self._size_height, self._size_width)),
tv.ToTensor(),
tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
),
]
)
def __call__(self, image: Image.Image) -> torch.Tensor:
return self.tv_transform(image)
class VariableSizeImageTransform(object):
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
The algorithm will NOT distort the image fit a certain aspect ratio, because
that leads to a significant degradation in image quality.
It can be summarized in 6 steps:
1. Find all possible canvas combinations of max_num_chunks;
2. Find the best canvas to fit the image;
3. Resize without distortion
4. Pad
5. Normalize
6. Chunk
For example, if an input image is of size 300x800, patch_size of 224,
and max_num_chunks = 8, it will find the closest aspect ratio that
is allowed within 8 image chunks, with some restrictions.
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
giving a total of 8 chunks.
If resize_to_max_canvas, the image will be resized (without distortion),
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
where we maintain the original aspect ratio and pad with zeros value for the rest.
This approach minimizes the amount of padding required for any arbitrary resolution.
However, if limit_upscaling_to_patch_size is set to True,
the upscaling will be limited to the patch size. In the example above,
the image would remain 300x800 (no upscaling), and then padded to 448:896.
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
patches are coming from the resizing and chunking.
"""
def __init__(self, size: int = IMAGE_RES) -> None:
self.size = size
self.to_tensor = tv.ToTensor()
self._mean = (0.5, 0.5, 0.5)
self._std = (0.5, 0.5, 0.5)
self.normalize = tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
)
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
n (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
factors_set.add(i)
factors_set.add(n // i)
return factors_set
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
"""
Computes all of the allowed resoltuions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (int): Size of the side of the patch.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 5
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(self.get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for value in asp_dict.values():
for height, width in value:
possible_resolutions.append((height * patch_size, width * patch_size))
return possible_resolutions
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
(134, 200)
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
(450, 338)
"""
original_width, original_height = image_size
target_width, target_height = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_width, new_height
def _pad(self, image: Image.Image, target_size) -> Image.Image:
new_width, new_height = target_size
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
new_im.paste(image)
return new_im
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
num_channels, height, width = image.size()
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
# Permute dimensions to reorder the axes
image = image.permute(1, 3, 0, 2, 4).contiguous()
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
return image
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
If target_size requires upscaling the image, the user can set max_upscaling_size to
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
modifying target_size works as a boundary for the image's largest side.
Args:
resample (str): Resampling method used when resizing images.
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
max_upscaling_size (int): The maximum size to upscale the image to.
If None, there is no limit.
Examples:
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(600, 300) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (2000, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 100) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 2000
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = None
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
"""
image_width, image_height = image.size
image_size = (image_width, image_height)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
if max_upscaling_size is not None:
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
target_size = (new_target_width, new_target_height)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
image = F.resize(
image,
(
max(new_size_without_distortion[1], 1),
max(new_size_without_distortion[0], 1),
),
interpolation=self.resample,
)
return image
def get_best_fit(
self,
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
So we pick the resolution with the smallest smallest area:
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_width, original_height = image_size
# get all possible resolutions heights/widths
target_widths, target_heights = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
def __call__(
self,
image: Image.Image,
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Args:
image (PIL.Image): Image to be resized.
max_num_chunks (int): Maximum number of chunks to split the image into.
normalize_img (bool): Whether to normalize the image.
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
If True, picks the canvas the allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
"""
assert max_num_chunks > 0
assert isinstance(image, Image.Image), type(image)
w, h = image.size
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
possible_resolutions = torch.tensor(possible_resolutions)
best_resolution = self.get_best_fit(
image_size=(w, h),
possible_resolutions=possible_resolutions,
resize_to_max_canvas=resize_to_max_canvas,
)
max_upscaling_size = None if resize_to_max_canvas else self.size
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
image = self._pad(image, best_resolution)
image = self.to_tensor(image)
if normalize_img:
image = self.normalize(image)
ratio_w, ratio_h = (
best_resolution[0] // self.size,
best_resolution[1] // self.size,
)
image = self._split(image, ratio_w, ratio_h) # type: ignore
ar = (ratio_h, ratio_w)
return image, ar

View file

@ -0,0 +1,207 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from typing import Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor
from torch.nn import functional as F
from ..generation import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
def experts_batched_swiglu_wrapper(
self,
x: Tensor, # (e, g, D)
w1: Tensor, # (e, D, F)
w3: Tensor, # (e, D, F)
w2: Tensor, # (e, F, D)
) -> torch.Tensor:
from ...quantize_impls import bmm_nt
middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3) # noqa: N806
return bmm_nt(middle_out_egF, w2)
def convert_to_quantized_model(
model: Transformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
use_rich_progress: bool = True,
) -> Transformer:
from ...quantize_impls import (
Fp8ScaledWeights,
Int4ScaledWeights,
load_fp8,
load_int4,
quantize_fp8,
quantize_int4,
)
rank = get_model_parallel_rank()
use_rich_progress = use_rich_progress and rank == 0
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True)
int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
def apply_quantization(key, weight):
scale = int4_scales[key]
zero_point = int4_zero_points[key]
return load_int4(
weight,
scale,
zero_point,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
else:
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
if os.path.isfile(fp8_scales_path):
log_status(f"Rank {rank}: Loading fp8 scales")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
def apply_quantization(key, weight):
scale = fp8_scales[key]
return load_fp8(
weight,
scale,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
else:
log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
def apply_quantization(_, weight):
return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
processed_blocks = 0
try:
if use_rich_progress:
progress.start()
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
# Skip quantization on first and last layers
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
# Skip quantization on dense layers
if not isinstance(block.feed_forward, MoE):
continue
update_status(f"Rank {rank} - Layer {block.layer_id}")
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
)
processed_blocks += 1
update_status(message=None, completed=processed_blocks)
update_status(f"Rank {rank} - Moving parameters to CUDA")
param_count = 0
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
parameter.data = parameter.to(device="cuda")
param_count += 1
update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
finally:
if use_rich_progress:
progress.stop()
return model
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
console = None
if use_rich_progress:
from rich.console import Console
console = Console(highlight=False)
def log_status(message: str) -> None:
if use_rich_progress:
console.print(message)
elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message)
total_blocks = sum(
1
for _, block in model.named_modules()
if (
isinstance(block, TransformerBlock)
and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
and isinstance(block.feed_forward, MoE)
)
)
progress = None
if use_rich_progress:
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
progress = Progress(
SpinnerColumn(),
BarColumn(complete_style="green", finished_style="bright_green"),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
TextColumn("ETA:"),
TimeRemainingColumn(),
TextColumn("[bold]{task.fields[status]}"),
console=console,
expand=True,
)
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
if use_rich_progress:
if message is not None:
progress.update(task_id, status=message)
if completed is not None:
progress.update(task_id, completed=completed)
elif rank == 0 and completed and completed % 10 == 0:
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
return progress, log_status, update_status

View file

@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from ..args import VisionArgs
from .encoder import VisionEncoder
class PixelShuffle(nn.Module):
def __init__(self, ps_ratio):
super().__init__()
self.ps_ratio = ps_ratio
def forward(self, x):
# x: [B, N, C], N = number of patches
assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
hh = ww = int(math.sqrt(x.shape[1]))
x = x.reshape(x.shape[0], hh, ww, -1)
x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
return pixel_shuffle_patches
def pixel_shuffle_op(input_x, ps_ratio):
n, w, h, c = input_x.size()
input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
input_x = input_x.permute(0, 2, 1, 3).contiguous()
input_x = input_x.view(
n,
int(h * ps_ratio),
int(w * ps_ratio),
int(c / (ps_ratio * ps_ratio)),
)
input_x = input_x.permute(0, 2, 1, 3).contiguous()
return input_x
class SimpleMLP(torch.nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
bias: bool = True,
dropout: float = 0.0,
act_layer: Callable = nn.GELU,
):
super().__init__()
# layers
self.c_fc = ColumnParallelLinear(
dim,
hidden_dim,
bias=bias,
gather_output=False,
)
self.c_proj = RowParallelLinear(
hidden_dim,
hidden_dim,
bias=bias,
input_is_parallel=True,
)
self.non_linearity = act_layer()
self.dropout = dropout
def forward(self, x):
hidden = self.c_fc(x)
hidden = self.non_linearity(hidden)
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
return self.non_linearity(self.c_proj(hidden))
class PixelShuffleMLP(torch.nn.Module):
def __init__(
self,
ps_ratio: float,
input_dim: int,
output_dim: int = 4096,
add_fc: bool = False,
):
super().__init__()
self.pixel_shuffle = PixelShuffle(ps_ratio)
self.mlp = SimpleMLP(
int(input_dim // (ps_ratio**2)),
output_dim,
bias=False,
dropout=0.0,
act_layer=nn.GELU,
)
self.fc = nn.Identity()
if add_fc:
self.fc = ColumnParallelLinear(
output_dim,
output_dim,
bias=False,
)
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = self.pixel_shuffle(encoded_patches)
return self.fc(self.mlp(encoded_patches))
class VisionEmbeddings(torch.nn.Module):
def __init__(self, args: VisionArgs):
super().__init__()
self.args = args
image_size = args.image_size
patch_size = args.patch_size
self.vision_encoder = VisionEncoder(
image_size=(image_size.height, image_size.width),
patch_size=(patch_size.height, patch_size.width),
dim=args.dim,
layers=args.n_layers,
heads=args.n_heads,
mlp_ratio=args.mlp_ratio,
)
self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
self.vision_adapter = PixelShuffleMLP(
ps_ratio=args.pixel_shuffle_ratio,
input_dim=args.dim,
output_dim=args.output_dim,
)
self.output_dim = args.output_dim
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
return_state_dict: bool = False,
) -> None:
original_sd = self.state_dict()
for k in state_dict:
if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
def _get_empty_sequence(self, h):
return torch.zeros(
h.shape[0],
h.shape[1],
self.output_dim,
device=h.device,
dtype=h.dtype,
)
# x_images is batched; each batch sample contains a list of images. so this is List[List[torch.Tensor]]
# each image is a tensor of shape [num_tiles, C, H, W]
def forward(
self,
image_batch: List[List[torch.Tensor]],
image_mask: torch.Tensor,
h_ref: torch.Tensor,
) -> torch.Tensor:
images_flattened = [image for sample in image_batch for image in sample]
images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
embedding = self.vision_encoder(images_flattened)
projected_embedding = self.vision_adapter(embedding)
h_image = self._get_empty_sequence(h_ref)
return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
# If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
# `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
# `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
assert not torch.isnan(encoded_patches_proj).any()
assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
)
encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
for index in range(h_image.size(0)):
encoded_patches_per_sample = encoded_patches_list[index]
sample_image_mask = image_mask[index]
if encoded_patches_per_sample.numel() == 0:
continue
encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
-1, encoded_patches_per_sample.size(-1)
)
n_tokens_to_fill = sample_image_mask.sum()
assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
h_image[index].masked_scatter_(
sample_image_mask.expand(-1, h_image.size(-1)),
encoded_patches_per_sample[:n_tokens_to_fill],
)
return h_image

View file

@ -0,0 +1,411 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import einsum
from ..args import ModelArgs
from ..model import Attention
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x
class ColumnParallelConv2dPatch(torch.nn.Module):
"""Conv2D Patching layer with model parallelism.
Column parallel over unfolded input.
Arguments:
in_channels: Input channels.
out_channels: Output channels.
kernel_size: Size of convolution kernel.
stride (default 1): Stride for convolution.
bias (default False): Use bias in Conv2d.
Input: (bsz, in_channels, height, width)
Output: (bsz, num_tokens, out_channels)
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: Optional[bool] = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
self._linear = ColumnParallelLinear(
in_channels * kernel_size[0] * kernel_size[1],
out_channels,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._unfold(x)
x = x.permute(0, 2, 1)
x = self._linear(x)
return x
class _FeedForward(torch.nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
dropout: float,
act_layer: Callable = nn.GELU,
):
super().__init__()
# layers
self.c_fc = ColumnParallelLinear(
dim,
hidden_dim,
bias=True,
gather_output=False,
init_method=lambda x: x,
)
self.c_proj = RowParallelLinear(
hidden_dim,
dim,
bias=True,
input_is_parallel=True,
init_method=lambda x: x,
)
self.non_linearity = act_layer()
self.dropout = dropout
def forward(self, x):
hidden = self.c_fc(x)
hidden = self.non_linearity(hidden)
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
return self.c_proj(hidden)
class _TransformerBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
gated: bool = False,
):
super().__init__()
assert d_model % n_head == 0
self.n_heads = n_head
self.head_dim = d_model // self.n_heads
attn_args = ModelArgs(
dim=d_model,
head_dim=self.head_dim,
n_heads=self.n_heads,
n_kv_heads=self.n_heads,
)
self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
self.ln_1 = LayerNorm(d_model)
self.mlp = _FeedForward(
dim=d_model,
hidden_dim=int(mlp_ratio * d_model),
dropout=0.0,
act_layer=act_layer,
)
self.ln_2 = LayerNorm(d_model)
self.gated = gated
if gated:
self.gate_attn = nn.Parameter(torch.zeros(1))
self.gate_ffn = nn.Parameter(torch.zeros(1))
def attention(
self,
x: torch.Tensor,
freq_cis: Optional[torch.Tensor] = None,
):
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
freq_cis: Optional[torch.Tensor] = None,
):
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
x = x + _gate_ffn * self.mlp(self.ln_2(x))
return x
class _Transformer(nn.Module):
def __init__(
self,
dim: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
gated: bool = False,
):
super().__init__()
self.resblocks = nn.ModuleList(
[
_TransformerBlock(
d_model=dim,
n_head=heads,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
gated=gated,
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
out = []
for idx, r in enumerate(self.resblocks):
if return_intermediate is not None and idx in return_intermediate:
out.append(x)
x = r(x, mask=mask, freq_cis=freq_cis)
if return_intermediate is not None:
return x, torch.stack(out, dim=-1)
return x
class PackingIndex:
Z = 0 # Z (time) coordinate of the token in the original sample
Y = 1 # Y (height) coordinate of the token in the original sample
X = 2 # X (width) coordinate of the token in the original sample
TIME = 3 # Total number of time units (frames) in the original sample
HEIGHT = 4 # Height of the original sample
WIDTH = 5 # Width of the original sample
# USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
IDX = 6 # Full index of the token in the original sample (x + y * w + z * w * h)
BATCH_IDX = 7 # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
# Total size of the enum, remember to update this!
NUM_METADATA = 8
# Note: For padding tokens IDX = -1
# For cls tokens, IDX = -2
ID_CLS_TOKEN = -2
ID_PAD_TOKEN = -1
class VisionEncoder(nn.Module):
def __init__(
self,
image_size: Tuple[int, int],
patch_size: Tuple[int, int],
dim: int,
layers: int,
heads: int,
mlp_ratio: float,
in_channels: int = 3,
):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.grid_size = (
self.image_size[0] // self.patch_size[0],
self.image_size[1] // self.patch_size[1],
)
self.conv1 = ColumnParallelConv2dPatch(
in_channels=in_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = dim**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(dim))
self.positional_embedding_vlm = nn.Parameter(
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
)
self.ln_pre = LayerNorm(dim)
self.ln_post = LayerNorm(dim)
self.transformer = _Transformer(
dim,
layers,
heads,
mlp_ratio,
act_layer=nn.GELU,
)
# NOTE: hack for the fixed res
image_h, image_w = self.image_size
patch_h, patch_w = self.patch_size
idx_h, idx_w = image_h // patch_h, image_w // patch_w
img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
img_idx = img_idx.reshape(idx_h * idx_w, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
packed_img_idx = torch.empty(
img_idx.shape[0],
img_idx.shape[1],
PackingIndex.NUM_METADATA - 1,
dtype=torch.int32,
)
packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
packed_img_idx[:, :, PackingIndex.IDX] = img_idx
packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
self.packed_img_idx = packed_img_idx # for positional embedding load hook
# compute rope freqs
rope_freq = self.get_rope_freqs(dim // heads // 2)
freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
# disable RoPE for padding and cls tokens
freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
# compute complex freqs
self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
# xlf automatically broadcasts
self.freq_cis = self.freq_cis.squeeze(0)
self.n_heads = heads // fs_init.get_model_parallel_world_size()
self._register_load_state_dict_pre_hook(self.load_hook)
def get_rope_freqs(self, dim, theta=10000):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
return freqs
@torch.amp.autocast("cuda", enabled=False)
def compute_rope_freqs(self, freqs, t):
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
freqs = freqs.repeat_interleave(2, dim=-1)
return freqs
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
raise ValueError(
f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
)
batch_size, token_per_image, _ = self.packed_img_idx.shape
# Input points for idx are [x, y, w, h]
idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
total_windows, window_size, _ = idx.shape
# Grid values are [-1, 1] and coords are w, h
grid = (
(idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
)[None, ...]
# In this mode, cls token has no position embedding
if orig_pos_embed is not None:
posemb = (
orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
)
posemb = posemb.to(device=grid.device, dtype=grid.dtype)
sample = F.grid_sample(
posemb, grid, padding_mode="zeros"
) # padding tokens / class token will get zero for posemb
sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
sample = torch.where(
idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
sample,
)
new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
if return_state_dict:
return state_dict
def apply_class_embedding(self, x):
x = torch.cat(
[
x,
self.class_embedding.to(x.dtype)
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
return x
def forward(self, images: torch.Tensor) -> torch.Tensor:
# NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
if images.ndim == 5:
num_concurrent_media = 1
bsz, num_chunks, nch, h, w = images.shape
else:
bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
# patch embedding
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
x = self.conv1(x) # shape = [*, width, grid ** 2]
_, ntok, dim = x.shape
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
# apply cls token
x = self.apply_class_embedding(x)
ntok += 1
# apply position embeddings
if self.positional_embedding_vlm is not None:
x = x + self.positional_embedding_vlm.to(x.dtype)
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
x = self.ln_pre(x)
x = x.view(bsz * num_concurrent_media, -1, dim)
freq_cis = self.freq_cis.to(images.device)
tf_output = self.transformer(
x,
freq_cis=freq_cis,
)
int_x = None
if isinstance(tf_output, tuple):
x, int_x = tf_output
else:
x = tf_output
x = self.ln_post(x)
# remove cls token output
x = x[:, :-1, :]
# add and output x + int_x features
if int_x is not None:
int_x = int_x[:, :-1, :, :]
int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
x = torch.cat([x, int_x], dim=-1)
return x

View file

@ -4,23 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from copy import deepcopy
from functools import partial
from typing import Any, Generator
from typing import Any, Callable, Generator
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig
from .llama3.generation import Llama3
from .parallel_utils import ModelParallelProcessGroup
@ -39,11 +33,10 @@ class ModelRunner:
def init_model_cb(
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
builder_fn: Callable,
params: list[Any],
):
llama = Llama3.build(config, model_id, llama_model)
llama = builder_fn(*params)
return ModelRunner(llama)
@ -60,25 +53,15 @@ class LlamaModelParallelGenerator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
model_parallel_size: int,
builder_fn: Callable,
builder_params: list[Any],
formatter: Llama3ChatFormat | Llama4ChatFormat,
):
self.config = config
self.model_id = model_id
self.llama_model = llama_model
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
# while the tool-use loop is going
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
checkpoint_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
self.model_parallel_size = model_parallel_size
self.builder_fn = builder_fn
self.builder_params = builder_params
self.formatter = formatter
def start(self):
self.__enter__()
@ -87,11 +70,9 @@ class LlamaModelParallelGenerator:
self.__exit__(None, None, None)
def __enter__(self):
model_parallel_size = self.llama_model.pth_file_count
self.group = ModelParallelProcessGroup(
model_parallel_size,
init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
self.model_parallel_size,
init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
)
self.group.start()
return self

View file

@ -1,177 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
import logging
from typing import Optional, Type
log = logging.getLogger(__name__)
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
raise
import torch
from torch import Tensor, nn
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
Fp8ScaledWeights,
collections.namedtuple(
"Fp8RowwiseWeights",
["weight", "scale", "shape", "activation_scale_ub"],
),
):
pass
def ffn_swiglu(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
assert isinstance(w1, Tensor)
assert isinstance(w3, Tensor)
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
del x1, x2
assert isinstance(w2, Tensor)
return (z @ w2.T).view(B, T, D)
@torch.inference_mode()
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
weight=wq,
scale=w_scale,
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
Args:
w (Tensor): [n, k] input FP8.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
return Fp8RowwiseWeights(
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
scale=w_scale.to(device="cuda"),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
def fc_fp8_dynamic(
x: Tensor,
w: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
Single w8a8 fc layer with dynamic row-wise scaling.
"""
if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
del xq
return y
def ffn_swiglu_fp8_dynamic(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
(B, T, D) = x.shape # noqa: N806
HD_L = w1.shape[0] # noqa: N806
assert HD_L == w3.shape[0]
x1 = fc_fp8_dynamic(
x.view(B * T, D),
w1,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
x2 = fc_fp8_dynamic(
x.view(B * T, D),
w3,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
z = torch.nn.functional.silu(x1) * x2
del x1, x2
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
return z_.view(B, T, D)

View file

@ -1,78 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# The file gets a special treatment for now?
# ruff: noqa: N803
import unittest
import torch
from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
from hypothesis import given, settings
from hypothesis import strategies as st
from torch import Tensor
@unittest.skipIf(
not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):
@settings(deadline=None)
@given(
D=st.sampled_from([4096, 8192]),
HD_L=st.sampled_from([1280, 2560]),
B=st.sampled_from([1, 2]),
T=st.sampled_from([2048, 4096]),
UB=st.sampled_from([1000, 10000]),
)
def test_fp8_ffn(
self,
D: int, # noqa
HD_L: int,
B: int,
T: int,
UB: float,
) -> None:
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
# Fake quant
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
v_ref = ref_ffn(x, w1, w3, w2)
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
if __name__ == "__main__":
unittest.main()

View file

@ -1,152 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
import fire
import torch
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from torch.nn.parameter import Parameter
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs
from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
quantize_fp8,
)
log = logging.getLogger(__name__)
def main(
ckpt_dir: str,
tokenizer_path: str,
quantized_ckpt_dir: str,
max_seq_len: Optional[int] = 512,
max_batch_size: Optional[int] = 4,
model_parallel_size: Optional[int] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
):
""" """
if not os.path.exists(quantized_ckpt_dir):
os.makedirs(quantized_ckpt_dir)
shutil.copy(
os.path.join(ckpt_dir, "params.json"),
os.path.join(quantized_ckpt_dir, "params.json"),
)
shutil.copy(
os.path.join(ckpt_dir, "tokenizer.model"),
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
log.info(ckpt_path)
assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
fp8_scales = {}
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
fp8_weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join(
quantized_ckpt_dir,
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
)
torch.save(model.state_dict(), ckpt_path)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -1,31 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
set -x
cd $(dirname "$(realpath "$0")")
MASTER_HOST=$1
RUN_ID=$2
CKPT_DIR=$3
QUANT_CKPT_DIR=$4
TOKENIZER_PATH=$5
NNODES=$6
NPROC=$7
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-stack" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \
--rdzv_conf='timeout=120' \
--rdzv_backend=c10d \
--rdzv_endpoint="${MASTER_HOST}:29502" \
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR

View file

@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# type: ignore
import collections
import logging
from typing import Optional, Tuple, Type, Union
log = logging.getLogger(__name__)
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
except ImportError:
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
raise
import torch
from torch import Tensor, nn
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
Fp8ScaledWeights,
collections.namedtuple(
"Fp8RowwiseWeights",
["weight", "scale", "shape", "activation_scale_ub"],
),
):
pass
class Int4ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Int4Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Int4Weights(
Int4ScaledWeights,
collections.namedtuple(
"Int4Weights",
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
),
):
pass
def int4_row_quantize(
x: torch.Tensor,
group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float)
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
# Recenter output and move to int8.
out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
# Cutlass expects column major layout for scale and zero point,
# so we transpose here and make them contiguous.
scales = scales.view(x.shape[0], -1).t().contiguous()
zeros = zeros.view(x.shape[0], -1).t().contiguous()
return out, scales, zeros
def pack_int4(x: torch.Tensor) -> torch.Tensor:
# Given int8 x, pack adjacent int4 values into a single int8.
low_x = x[:, ::2]
high_x = x[:, 1::2]
# High bits need to left shift, this also masks off extra bits.
high_x = torch.bitwise_left_shift(high_x, 4)
# Low bits need to have sign bits removed.
low_x = torch.bitwise_and(low_x, 0xF)
# Recombine into a single value with bitwise or.
return torch.bitwise_or(low_x, high_x).contiguous()
def bmm_nt(
x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights],
num_tokens: Optional[Tensor] = None,
) -> Tensor:
if isinstance(w, Fp8ScaledWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
elif isinstance(w, Int4ScaledWeights):
return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
else:
raise ValueError("Unsupported quantization type")
def ffn_swiglu(
x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights],
w3: Union[Fp8RowwiseWeights, Int4Weights],
w2: Union[Fp8RowwiseWeights, Int4Weights],
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
):
return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
assert isinstance(w1, Tensor)
assert isinstance(w3, Tensor)
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
del x1, x2
assert isinstance(w2, Tensor)
return (z @ w2.T).view(B, T, D)
@torch.inference_mode()
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
weight=wq,
scale=w_scale,
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def quantize_int4(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Quantize [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
if w.ndim >= 3:
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
scale = torch.stack(scale, dim=0)
zero_point = torch.stack(zero_point, dim=0)
else:
wq, scale, zero_point = int4_row_quantize(w)
wq = pack_int4(wq)
del w
return Int4Weights(
weight=wq.to(output_device),
scale=scale.to(output_device),
zero_point=zero_point.to(output_device),
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
Args:
w (Tensor): [n, k] input FP8.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
return Fp8RowwiseWeights(
weight=w.to(torch.float8_e4m3fn).to(device=output_device),
scale=w_scale.to(device=output_device),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_int4(
w: Tensor,
scale: Tensor,
zero_point: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Load INT4 [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input INT4.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
return Int4Weights(
weight=w.to(torch.int8).to(device=output_device),
scale=scale.to(device=output_device),
zero_point=zero_point.to(device=output_device),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
def fc_dynamic(
x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights],
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dyanmic row-wise scaling
"""
if isinstance(w, Int4Weights):
y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
else:
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
del xq
return y
def ffn_swiglu_dynamic(
x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights],
w3: Union[Fp8RowwiseWeights, Int4Weights],
w2: Union[Fp8RowwiseWeights, Int4Weights],
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
assert x.dim() == 3 or x.dim() == 2
if x.dim() == 3:
(B, T, D) = x.shape # noqa: N806
else:
(T, D) = x.shape # noqa: N806
B = 1 # noqa: N806
HD_L = w1.shape[0] # noqa: N806
assert HD_L == w3.shape[0]
x1 = fc_dynamic(
x.view(B * T, D),
w1,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
x2 = fc_dynamic(
x.view(B * T, D),
w3,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
z = torch.nn.functional.silu(x1) * x2
del x1, x2
z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
if x.dim() == 3:
return z_.view(B, T, D)
else:
return z_

View file

@ -126,7 +126,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = event.span_id
span_id = int(event.span_id, 16)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:

View file

@ -39,13 +39,7 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.inference,
provider_type="inline::meta-reference-quantized",
pip_packages=(
META_REFERENCE_DEPS
+ [
"fbgemm-gpu",
"torchao==0.5.0",
]
),
pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
module="llama_stack.providers.inline.inference.meta_reference",
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
),

View file

@ -27,7 +27,7 @@ def supported_inference_models() -> List[Model]:
m
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3, ModelFamily.llama4}
or is_supported_safety_model(m)
)
]

View file

@ -33,9 +33,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_openai_chat_completion_choice,
@ -55,10 +53,22 @@ class LiteLLMOpenAIMixin(
Inference,
NeedsRequestProviderData,
):
def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
def __init__(
self,
model_entries,
api_key_from_config: Optional[str],
provider_data_api_key_field: str,
openai_compat_api_base: str | None = None,
):
ModelRegistryHelper.__init__(self, model_entries)
self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field
self.api_base = openai_compat_api_base
if openai_compat_api_base:
self.is_openai_compat = True
else:
self.is_openai_compat = False
async def initialize(self):
pass
@ -98,6 +108,7 @@ class LiteLLMOpenAIMixin(
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -111,6 +122,9 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
if self.is_openai_compat:
params["model"] = "openai/" + params["model"]
logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
@ -208,6 +222,7 @@ class LiteLLMOpenAIMixin(
return {
"model": request.model,
"api_key": api_key,
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),

View file

@ -573,21 +573,24 @@ async def convert_message_to_openai_dict_new(
content=await _convert_message_content(message.content),
)
elif isinstance(message, CompletionMessage):
tool_calls = [
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
]
params = {}
if tool_calls:
params = {"tool_calls": tool_calls}
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
]
or None,
**params,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
@ -801,7 +804,7 @@ def _convert_openai_logprobs(
- token, logprob
"""
if not logprobs:
if not logprobs or not logprobs.content:
return None
return [

View file

@ -224,7 +224,9 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
return formatter.tokenizer.decode(model_input.tokens)
async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest,
) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
@ -302,8 +304,12 @@ def chat_completion_request_to_messages(
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
# llama3.2 and llama3.3 models follow the same tool prompt format
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_2(request)
else:
messages = request.messages
@ -471,7 +477,11 @@ def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return ToolPromptFormat.json
elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
elif llama_model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# llama3.2 and llama3.3 models follow the same tool prompt format
return ToolPromptFormat.python_list
else: