refactor: move all llama code to models/llama out of meta reference (#1887)

# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference specific
tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to
it using:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
This commit is contained in:
Ashwin Bharambe 2025-04-07 15:03:58 -07:00 committed by GitHub
parent c52ccc4bbd
commit 530d4bdfe1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
85 changed files with 1267 additions and 1683 deletions

View file

@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones
return [new_mp_rank * old_mp_size // new_mp_size]
elif old_mp_size % new_mp_size == 0:
# Merge old MP shards into a single one
mp_factor = old_mp_size // new_mp_size
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
else:
raise ValueError(
f"Either old MP size or new MP size should be a multiple of the other: "
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
)
def maybe_reshard_state_dict(
ckpt_paths: List[Path],
n_kv_heads: int,
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
mmap: bool = True,
) -> Dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
ckpt_paths = np.array(sorted(ckpt_paths))
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
old_mp_size = len(ckpt_paths)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
paths = ckpt_paths[old_mp_ranks] # type: ignore
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
if new_mp_size == old_mp_size:
return state_dicts[0] # type: ignore
if moe_num_experts is not None:
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
return reshard_mp(
state_dicts,
size=max(new_mp_size // old_mp_size, 1),
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
)
_WEIGHT_ROW_KEY = {
"feed_forward.w2",
"feed_forward.mlp.fc2",
"attention.wo",
"feed_forward.mlp.fc2_weight",
"feed_forward.w_out_shared_DF.weight",
"attn.wo.weight",
"mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
_WEIGHT_COLUMN_KEY = {
"output",
"feed_forward.(w1|w3)",
"feed_forward.mlp.(fc1|fc3)",
"feed_forward.mlp.fc1_weight",
"attention.(wk|wq|wv|wqkv).weight",
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
"attn.(wk|wq|wv).weight",
"attn.(wk|wq|wv).bias",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"conv1._linear.weight",
"tok_embeddings.weight",
"vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()
def process_key(key: str) -> torch.Tensor:
if row_regex.search(key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
elif column_regex.search(key):
if "w13" in key or "fc1_weight" in key:
dims = state_dicts[0][key].size()
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
return concat_or_chunk(values, dim=1).flatten(0, 1)
elif "qkv" in key:
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
elif "wk.weight" in key or "wv.weight" in key:
# Support MP > #kv_head
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
elif key == "output.bias" or key == "fc.weight":
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
elif "w_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
return state_dicts[0][key].clone()
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))
output: Dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
for future in concurrent.futures.as_completed(mappings):
output[mappings[future]] = future.result()
return output
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())
for key in keys:
if routed_regex.search(key):
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
return state_dict

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import base64
from enum import Enum
from io import BytesIO
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
@ -98,6 +89,29 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens"
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
@ -140,292 +154,25 @@ class RawMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
register_schema(ToolCall)
class GenerationResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
source: Literal["input"] | Literal["output"]
# index within the batch
batch_idx: int
# whether generation for this item is already finished. note that tokens can
# get returned even afterwards since other items in the batch can still be generating tokens
finished: bool
# because a batch is parallel processed, useful decoding for one item can correspond to processing
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
ignore_token: bool
@json_schema_type
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
@json_schema_type
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
recommended_sampling_params: Optional[SamplingParams] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.id.name
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
import re
import torch
from torch import nn
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
"""Hadamard transform.
This function performs the Hadamard transform on the input tensor 'x'.
The Hadamard transform is a linear transformation that multiplies the input
tensor by the Hadamard matrix of dimension n x n, where n is the size of
the last dimension of the input tensor.
"""
*_, n = x.shape
m = int(math.log2(n))
assert n == 1 << m, "n must be a power of 2"
x = x[..., None]
inv_sqrt2 = 0.5**0.5
for _ in range(m):
top = x[..., ::2, :] + x[..., 1::2, :]
bot = x[..., ::2, :] - x[..., 1::2, :]
x = torch.cat((top, bot), dim=-1)
x *= inv_sqrt2
res = x.squeeze(-2)
return res
class HadamardModule(torch.nn.Module):
"""A module that applies the Hadamard transform to the input tensor.
Args:
group_size: The size of the groups that the input tensor will be divided into
before applying the Hadamard transform.
"""
def __init__(self, group_size: int) -> None:
super().__init__()
self.group_size = group_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
reshape_back = False
orig_shape = x.shape
if self.group_size != x.shape[-1]:
reshape_back = True
x = x.reshape(-1, x.shape[-1] // self.group_size, self.group_size)
x = hadamard_transform(x)
if reshape_back:
x = x.reshape(orig_shape)
return x
def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
"""
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
This function recursively traverses the model's children and looks for layers that match the pattern
"layers.<digit>.feed_forward.w2", where <digit> is one or more digits. When such a layer is found,
it is replaced with a new sequential module that consists of a HadamardModule followed by the original
layer. The HadamardModule applies the Hadamard transform to the input tensor.
See `SpinQuant <https://arxiv.org/abs/2405.16406>_` paper for more details.
Args:
model: An instance of 'torch.nn.Module' (e.g., Transformer model).
prefix: A string prefix to add to the full name of each child module.
Returns:
None
"""
pattern_last_linear_ffn = r"layers.\d+.feed_forward.w2"
for module_name, module in model.named_children():
child_full_name = prefix + "." + module_name
if re.search(pattern_last_linear_ffn, child_full_name):
new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
del module
setattr(model, module_name, new_module)
else:
add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)

View file

@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class QuantizationScheme(Enum):
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
@dataclass
class QuantizationArgs:
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
spinquant: bool = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
if k == "scheme":
setattr(self, k, QuantizationScheme(v))
else:
if hasattr(self, k):
setattr(self, k, v)
@dataclass
class LoRAArgs:
rank: int
scale: float
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
max_batch_size: int = 32
max_seq_len: int = 2048
# vision model params
vision_chunk_size: int = -1 # image resolution for image models
vision_max_num_chunks: int = 4
vision_num_cross_attention_layers: int = -1
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
def __init__(self, **kwargs):
for k, v in kwargs.items():
if k == "lora_args":
setattr(self, k, LoRAArgs(**v))
elif k == "quantization_args":
setattr(self, k, QuantizationArgs(**v))
else:
if hasattr(self, k):
setattr(self, k, v)
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
assert self.n_kv_heads <= self.n_heads
assert self.n_heads % self.n_kv_heads == 0
assert self.dim % self.n_heads == 0

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import json
import uuid
@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from .tokenizer import Tokenizer
from .tool_utils import ToolUtils

View file

@ -0,0 +1,367 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer
class Llama3:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
device = torch.device(device)
if (
device.type == "cuda"
and not torch.cuda.is_available()
or device.type == "xpu"
and not torch.xpu.is_available()
):
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
if not torch.distributed.is_initialized():
if device.type == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if device.type == "cuda":
torch.cuda.set_device(local_rank)
elif device.type == "xpu":
torch.xpu.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer.get_instance()
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)
assert model_args.vocab_size == tokenizer.n_words
def build_model():
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else:
print(f"Setting default device to {device}")
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
for inp in model_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]
bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
text_only_inference,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
def completion(
self,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -16,7 +16,7 @@ from typing import List, Optional
from termcolor import colored
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,
@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (

View file

@ -0,0 +1,305 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from torch import nn
from .args import ModelArgs
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * torch.pi / freqs
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
return torch.where(
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
new_freqs,
)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
self.freqs_cis = precompute_freqs_cis(
params.dim // params.n_heads,
params.max_seq_len * 2,
params.rope_theta,
params.use_scaled_rope,
)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# https://github.com/pytorch/pytorch/issues/100005
# torch.triu is buggy when the device is mps: filled values are
# nan instead of 0.
if mask.device.type == torch.device("mps").type:
mask = torch.nan_to_num(mask, nan=0.0)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger
import torch
import torch.nn.functional as F
from .utils import get_negative_inf_value, to_2tuple
logger = getLogger()
def resize_local_position_embedding(orig_pos_embed, grid_size):
"""
Resize position embedding for vision encoder.
Original position embedding is [n_tiles * n_tiles + 1, dim]
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
"""
new_grid_size = to_2tuple(grid_size)
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
new_pos_emb_tok, new_pos_emb_img = (
orig_pos_embed[:1],
orig_pos_embed[1:],
)
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
new_pos_emb_img = F.interpolate(
new_pos_emb_img,
size=new_grid_size,
mode="bilinear",
align_corners=True,
)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
return new_pos_embed
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
"""
Takes a local position embedding for vision encoder and uses it
to initialize the global position embedding.
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
"""
pos_embed = pos_and_cls_embed[1:]
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
grid_size = to_2tuple(grid_size)
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
new_pos_emb_img = F.interpolate(
new_pos_emb_img,
size=new_grid_size,
mode="bilinear",
align_corners=True,
)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
return pos_and_cls_embed
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
"""
Takes a global position embedding for vision encoder and resizes it to new size.
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
"""
# first remove cls token
pos_embed = pos_and_cls_embed[:, :, 1:]
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
xs_old, ys_old, ntok, dim = pos_embed.shape
old_grid_size = int(math.sqrt(ntok))
# move to correct form for interpolation
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
pos_embed = pos_embed.unsqueeze(0)
# interpolate
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
pos_embed = pos_embed.permute(0, 3, 1, 2)
pos_embed_resized = F.interpolate(
pos_embed,
size=new_size,
mode="bilinear",
align_corners=True,
)
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
# move it back in place
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
# interpolate cls token
cls_embed = cls_embed.permute(2, 3, 0, 1)
cls_embed_resized = F.interpolate(
cls_embed,
size=(x_scale, y_scale),
mode="bilinear",
align_corners=True,
)
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
# add cls token back in
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
return pos_and_cls_embed
def build_encoder_attention_mask(
x: torch.Tensor,
ar: torch.Tensor,
ntok: int,
num_chunks: int,
n_heads: int,
):
"""
Build vision encoder attention mask that omits padding tokens.
"""
masks = []
for arx in ar:
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
mask_i[: arx[0] * arx[1], :ntok] = 0
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
mask_i = mask_i.unsqueeze(0)
masks.append(mask_i)
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
return masks
def expand_num_tokens_to_mult8(x):
num_pad_tokens = 8 - (x.shape[-2] % 8)
if num_pad_tokens == 0:
return x, 0
else:
return (
torch.cat(
[
x,
torch.zeros(
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
dtype=x.dtype,
device=x.device,
),
],
dim=-2,
),
num_pad_tokens,
)
def contract_num_tokens_from_mult8(x, num_pad_tokens):
if num_pad_tokens == 0:
return x
return x[:, :, :-num_pad_tokens]

View file

@ -0,0 +1,408 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from collections import defaultdict
from logging import getLogger
from typing import Any, Optional, Set, Tuple
import torch
import torchvision.transforms as tv
from PIL import Image
from torchvision.transforms import functional as F
IMAGE_RES = 224
logger = getLogger()
class VariableSizeImageTransform(object):
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
The algorithm will NOT distort the image fit a certain aspect ratio, because
that leads to a significant degradation in image quality.
It can be summarized in 6 steps:
1. Find all possible canvas combinations of max_num_chunks;
2. Find the best canvas to fit the image;
3. Resize without distortion
4. Pad
5. Normalize
6. Chunk
For example, if an input image is of size 300x800, patch_size of 224,
and max_num_chunks = 8, it will find the closest aspect ratio that
is allowed within 8 image chunks, with some restrictions.
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
giving a total of 8 chunks.
If resize_to_max_canvas, the image will be resized (without distortion),
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
where we maintain the original aspect ratio and pad with zeros value for the rest.
This approach minimizes the amount of padding required for any arbitrary resolution.
However, if limit_upscaling_to_patch_size is set to True,
the upscaling will be limited to the patch size. In the example above,
the image would remain 300x800 (no upscaling), and then padded to 448:896.
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
patches are coming from the resizing and chunking.
"""
def __init__(self, size: int = IMAGE_RES) -> None:
self.size = size
logger.info(f"VariableSizeImageTransform size: {self.size}")
self.to_tensor = tv.ToTensor()
self._mean = (0.48145466, 0.4578275, 0.40821073)
self._std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
)
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
n (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
factors_set.add(i)
factors_set.add(n // i)
return factors_set
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
"""
Computes all of the allowed resoltuions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (int): Size of the side of the patch.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 5
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(self.get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for value in asp_dict.values():
for height, depth in value:
possible_resolutions.append((height * patch_size, depth * patch_size))
return possible_resolutions
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
(134, 200)
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
(450, 338)
"""
original_width, original_height = image_size
target_width, target_height = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_width, new_height
def _pad(self, image: Image.Image, target_size) -> Image.Image:
new_width, new_height = target_size
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
new_im.paste(image)
return new_im
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
num_channels, height, width = image.size()
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
# Permute dimensions to reorder the axes
image = image.permute(1, 3, 0, 2, 4).contiguous()
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
return image
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
If target_size requires upscaling the image, the user can set max_upscaling_size to
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
modifying target_size works as a boundary for the image's largest side.
Args:
resample (str): Resampling method used when resizing images.
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
max_upscaling_size (int): The maximum size to upscale the image to.
If None, there is no limit.
Examples:
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(600, 300) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (2000, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 100) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 2000
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = None
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
"""
image_width, image_height = image.size
image_size = (image_width, image_height)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
if max_upscaling_size is not None:
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
target_size = (new_target_width, new_target_height)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
image = F.resize(
image,
(new_size_without_distortion[1], new_size_without_distortion[0]),
interpolation=self.resample,
)
return image
def get_best_fit(
self,
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
So we pick the resolution with the smallest smallest area:
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_width, original_height = image_size
# get all possible resolutions heights/widths
target_widths, target_heights = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
def __call__(
self,
image: Image.Image,
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[Any, Any]:
"""
Args:
image (PIL.Image): Image to be resized.
max_num_chunks (int): Maximum number of chunks to split the image into.
normalize_img (bool): Whether to normalize the image.
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
If True, picks the canvas the allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
"""
assert max_num_chunks > 0
assert isinstance(image, Image.Image), type(image)
w, h = image.size
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
possible_resolutions = torch.tensor(possible_resolutions)
best_resolution = self.get_best_fit(
image_size=(w, h),
possible_resolutions=possible_resolutions,
resize_to_max_canvas=resize_to_max_canvas,
)
max_upscaling_size = None if resize_to_max_canvas else self.size
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
image = self._pad(image, best_resolution)
image = self.to_tensor(image)
if normalize_img:
image = self.normalize(image)
ratio_w, ratio_h = (
best_resolution[0] // self.size,
best_resolution[1] // self.size,
)
image = self._split(image, ratio_w, ratio_h) # type: ignore
ar = (ratio_h, ratio_w)
return image, ar

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import collections
import torch
def get_negative_inf_value(dtype):
return torch.finfo(dtype).min
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)

View file

@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any, List, Optional
from llama_stack.models.llama.datatypes import (
from llama_stack.apis.inference import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from ...datatypes import QuantizationMode
from ...quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
from ..model import Transformer, TransformerBlock
from ..multimodal.model import CrossAttentionTransformer
def swiglu_wrapper(
self,
x: Tensor,
):
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
elif quantization_mode == QuantizationMode.int4_mixed:
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
else:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
def convert_to_fp8_quantized_model(
model: Transformer,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer:
# Move weights to GPU with quantization
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
if os.path.isfile(fp8_scales_path):
print("Loading fp8 scales...")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
fp8_activation_scale_ub,
)
else:
print("Quantizing fp8 weights from bf16...")
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward) # type: ignore
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=device,
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device=device)
return model
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
"""
Int8DynActInt4WeightLinear with LoRA adaptor.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
device: Device to use.
group_size: Group size for quantization.
precision: Precision of quantization.
scales_precision: Precision of scales.
lora_rank: Rank of LoRA adaptor.
lora_scale: Scale of LoRA adaptor.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias=False,
device=None,
# quantization parameters
group_size: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
lora_rank: Optional[int] = None,
lora_scale: Optional[float] = None,
) -> None:
super().__init__(
in_features,
out_features,
bias=bias,
device=device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
self.lora_scale: Optional[float] = None
self.adaptor: Optional[nn.Sequential] = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
self.adaptor = nn.Sequential()
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
self.lora_scale = lora_scale
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
if self.adaptor is not None:
adaptor_out = self.adaptor(input_) * self.lora_scale
return module_out + adaptor_out
return module_out
class Int8WeightEmbedding(torch.nn.Embedding):
"""An embedding layer to load int8 weights.
Args:
num_embeddings: Number of embeddings.
embedding_dim: Embedding dimension.
padding_idx: Padding index.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
device=None,
) -> None:
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
class Int8WeightLinear(torch.nn.Linear):
"""A linear layer to load int8 weights.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
lora_rank: Optional[int],
lora_scale: Optional[float],
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
Note that the weights of embedding and output layers are quantized to int8.
"""
device = None
for module_name, module in model.named_children():
if module_name == "output":
quantized_module = Int8WeightLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif module_name == "tok_embeddings":
quantized_module = Int8WeightEmbedding(
num_embeddings=module.num_embeddings,
embedding_dim=module.embedding_dim,
padding_idx=module.padding_idx,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
bias=False,
group_size=group_size,
lora_rank=lora_rank,
lora_scale=lora_scale,
device=device,
)
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
return model
def convert_to_int4_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
)
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
lora_rank = None
lora_scale = None
else:
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
return cast(Transformer | CrossAttentionTransformer, model.to(device=device))

View file

@ -12,8 +12,7 @@
# the top-level of this source tree.
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from ..datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,
JsonCustomToolGenerator,

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path

View file

@ -16,7 +16,8 @@ import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")

View file

@ -3,10 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -4,12 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import textwrap

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from pathlib import Path

View file

@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel, model_validator
class QuantizationScheme(Enum):
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
class QuantizationArgs(BaseModel):
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
spinquant: bool = False
class LoRAArgs(BaseModel):
rank: int
scale: float
class MoEArgs(BaseModel):
num_experts: int = -1
capacity_factor: float = 1.0 # capacity factor determines how many tokens each expert can choose
auto_scale_F: bool = ( # noqa: N815
True # if true, rescales hidden_dim such that number of activated params is same as equivalent dense layer
)
top_k: int = 1
interleave_moe_layer_step: int = 1
class Size(BaseModel):
height: int
width: int
class VisionArgs(BaseModel):
image_size: Size
patch_size: Size
# parameters for the encoder transformer
dim: int
n_layers: int
n_heads: int
mlp_ratio: float
output_dim: int
pixel_shuffle_ratio: float
class ModelArgs(BaseModel):
dim: int = -1
n_layers: int = -1
n_heads: int = -1
n_kv_heads: Optional[int] = None
head_dim: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_exp: Optional[float] = None
norm_eps: float = 1e-5
attention_chunk_size: Optional[int] = None
rope_theta: float = 500000
use_scaled_rope: bool = False
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
attn_temperature_tuning: bool = False
floor_scale: float = 8192.0
attn_scale: float = 0.1
vision_args: Optional[VisionArgs] = None
moe_args: Optional[MoEArgs] = None
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
max_batch_size: int = 32
max_seq_len: int = 2048
@model_validator(mode="after")
def validate(self) -> "ModelArgs":
assert self.n_kv_heads <= self.n_heads, f"n_kv_heads ({self.n_kv_heads}) must be <= n_heads ({self.n_heads})"
assert self.n_heads % self.n_kv_heads == 0, (
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
)
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
return self

View file

@ -13,7 +13,7 @@ import torch
from PIL import Image as PIL_Image
# TODO: either fork these or move them to the common package
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
@ -24,16 +24,10 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
ResizeNormalizeImageTransform,
VariableSizeImageTransform,
)
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
@ -54,7 +48,7 @@ class TransformedImage:
aspect_ratio: Tuple[int, int]
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@ -171,7 +165,7 @@ class ChatFormat:
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = convert_rgba_to_rgb(image)
image = convert_image_to_rgb(image)
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
if image_tiles.shape[0] > 1:

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
@dataclass
class MaskedEmbedding:
embedding: torch.Tensor
mask: torch.Tensor
@dataclass
class LLMInput:
"""
This is the input to the LLM from the "user" -- the user in this case views the
Llama4 model holistically and does not care or know about its inner workings (e.g.,
whether it has an encoder or if it is early fusion or not.)
This is distinct from the "TransformerInput" class which is really the Llama4
backbone operating on early fused modalities and producing text output
"""
tokens: torch.Tensor
# images are already pre-processed (resized, tiled, etc.)
images: Optional[List[torch.Tensor]] = None
@dataclass
class TransformerInput:
"""
This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
are expected to be "embedded" via encoders sitting before this layer in the model.
"""
tokens: torch.Tensor
# tokens_position defines the position of the tokens in each batch,
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
# - when it is an int, the start position are the same for all batches
tokens_position: Union[torch.Tensor, int]
image_embedding: Optional[MaskedEmbedding] = None
@dataclass
class LLMOutput:
logits: torch.Tensor
TransformerOutput = LLMOutput

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from typing import Any, Dict, List
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import nn
from torch.nn import functional as F
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
do_reduce: bool = True,
):
super().__init__()
self.do_reduce = do_reduce
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
state_dict[prefix + "w1.weight"] = w1
state_dict[prefix + "w3.weight"] = w3
state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")
def forward(self, x):
x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
out = F.linear(x, self.w2.weight)
if self.do_reduce:
return reduce_from_model_parallel_region(out)
return out

View file

@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import codecs
import io
import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode
from .args import ModelArgs
from .chat_format import ChatFormat, RawContent, RawMessage
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
from .tokenizer import Tokenizer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
class Llama4:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
**params,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
)
tokenizer = Tokenizer.get_instance()
# TODO: params.json should always have correct vocab_size
if model_args.vocab_size == -1:
model_args.vocab_size = tokenizer.n_words
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
print("Model args:\n", model_args.model_dump_json(indent=2))
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
moe_num_experts=model_args.moe_args.num_experts,
)
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama4(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer, vision_args=args.vision_args)
@torch.inference_mode()
def generate(
self,
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
params = self.model.args
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
cprint("Input to model:\n", "yellow")
for inp in llm_inputs:
cprint(self.tokenizer.decode(inp.tokens), "grey")
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if echo:
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
image_embedding = None
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
image_mask = image_mask.unsqueeze(-1)
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
image_embedding = MaskedEmbedding(
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
mask=image_mask,
)
xformer_input = TransformerInput(
tokens=tokens[:, prev_pos:cur_pos],
tokens_position=prev_pos,
image_embedding=image_embedding,
)
xformer_output = self.model.forward(xformer_input)
logits = xformer_output.logits
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
def completion(
self,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -0,0 +1,442 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from torch import nn
from .args import ModelArgs
from .datatypes import TransformerInput, TransformerOutput
from .ffn import FeedForward
from .moe import MoE
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class L2Norm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
# TODO: this module needs to be moved into a separate file since it can be used by
# the vision encoder as well.
def __init__(
self,
args: ModelArgs,
use_qk_norm: bool,
use_rope: bool,
add_bias: bool = False,
):
super().__init__()
self.use_rope = use_rope
self.use_qk_norm = use_qk_norm
# For attention temperature tuning
self.attn_temperature_tuning = args.attn_temperature_tuning
self.floor_scale = args.floor_scale
self.attn_scale = args.attn_scale
self.n_heads = args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=add_bias,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.qk_norm = None
if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
if r != 0:
raise ValueError(
f"shape={tuple(wqkv.shape)} is not divisible by "
f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
)
wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
state_dict[prefix + "wq.weight"] = wq
state_dict[prefix + "wk.weight"] = wk
state_dict[prefix + "wv.weight"] = wv
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
if self.use_rope:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
if self.use_qk_norm:
xq = self.qk_norm(xq)
xk = self.qk_norm(xk)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
# the inference-time temperature tuning function is customized to not affect short context
# while working at very long context
if self.attn_temperature_tuning and not self.use_rope:
seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
# reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
attn_scales = attn_scales.view(1, seqlen, 1, 1)
xq = xq * attn_scales
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
xk = self.cache_k[:bsz, : start_pos + seqlen]
xv = self.cache_v[:bsz, : start_pos + seqlen]
xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
xk = xk.repeat_interleave(self.n_rep, dim=1)
xv = xv.repeat_interleave(self.n_rep, dim=1)
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
output = self.wo(attn_output)
return output
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
use_rope = not self.is_nope_layer
use_qk_norm = args.use_qk_norm and not self.is_nope_layer
self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
self.feed_forward = MoE(
dim=args.dim,
hidden_dim=int(args.ffn_exp * args.dim),
ffn_dim_multiplier=args.ffn_dim_multiplier,
multiple_of=args.multiple_of,
moe_args=args.moe_args,
)
else:
hidden_dim = int(4 * args.dim)
hidden_dim = int(2 * hidden_dim / 3)
if args.ffn_dim_multiplier is not None:
hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=hidden_dim,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
elif prefix + "feed_forward.norm.weight" in state_dict:
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
for k in (
"feed_forward.experts.mlp",
"feed_forward.mlp_shared",
"attention.wo",
"attention.wqkv",
):
if prefix + k + "._extra_state" in state_dict:
state_dict.pop(prefix + k + "._extra_state")
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
global_attn_mask: Optional[torch.Tensor],
local_attn_mask: Optional[torch.Tensor],
):
# The iRoPE architecture uses global attention mask for NoPE layers or
# if chunked local attention is not used
if self.is_nope_layer or local_attn_mask is None:
mask = global_attn_mask
else:
mask = local_attn_mask
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, args: ModelArgs, **kwargs) -> None:
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(TransformerBlock(layer_id, args))
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
self.freqs_cis = precompute_freqs_cis(
args.dim // args.n_heads,
args.max_seq_len * 2,
args.rope_theta,
args.use_scaled_rope,
)
vision_args = self.args.vision_args
if vision_args:
# circular import otherwise until we refactor out Attention
from .vision.embedding import VisionEmbeddings
self.vision_embeddings = VisionEmbeddings(vision_args)
self.vision_projection = ColumnParallelLinear(
vision_args.output_dim,
args.dim,
bias=False,
init_method=lambda x: x,
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "rope.freqs" in state_dict:
state_dict.pop(prefix + "rope.freqs")
@torch.inference_mode()
def forward(self, model_input: TransformerInput) -> TransformerOutput:
tokens = model_input.tokens
start_pos = model_input.tokens_position
assert isinstance(start_pos, int), (
"This implementation does not support different start positions per batch item"
)
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
if image_embedding := model_input.image_embedding:
h_image = self.vision_projection(image_embedding.embedding)
h = h * ~image_embedding.mask + h_image * image_embedding.mask
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
global_attn_mask, local_attn_mask = None, None
if seqlen > 1:
global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
# https://github.com/pytorch/pytorch/issues/100005
# torch.triu is buggy when the device is mps: filled values are
# nan instead of 0.
if global_attn_mask.device.type == torch.device("mps").type:
global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
if chunk_size := self.args.attention_chunk_size:
local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
h = self.norm(h)
output = self.output(h).float()
return TransformerOutput(logits=output)
# tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
# in the iRoPE architecture
def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
block_pos = torch.abs(
(torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
- (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
)
token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
return mask.to(device)

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N806
# pyre-strict
from typing import Any, Dict, List
import fairscale.nn.model_parallel.initialize as fs_init
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torch.nn import functional as F
from .args import MoEArgs
from .ffn import FeedForward
class Experts(nn.Module):
def __init__(
self,
num_local_experts: int,
dim: int,
hidden_dim: int,
) -> None:
super().__init__()
dtype = torch.get_default_dtype()
self.num_local_experts = num_local_experts
self.dim = dim
divide_factor = fs_init.get_model_parallel_world_size()
self.w1: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
)
self.w2: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
divide_exact(hidden_dim, divide_factor),
dim,
dtype=dtype,
)
)
self.w3: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
self.prefix = prefix
if prefix + "moe_w_in_eD_F" in state_dict:
e = self.num_local_experts
D = self.dim
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
def forward(
self,
routed_in_egD: torch.Tensor, # noqa: N803
) -> torch.Tensor:
e = self.num_local_experts
D = self.dim
x_egD = routed_in_egD.view(e, -1, D)
out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
out_egD = out_egD.view(-1, D)
return out_egD
def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
return torch.bmm(middle_out_egF, w2)
class MoE(torch.nn.Module):
"""
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
Several commonly used annotations include:
- a: bsz*slen
- E: number of experts
- e: number of local experts per ep (n_experts/ep)
- D: hidden dimension
- d: D/tp
- F: model dimension
- G: number of tokens per expert (a * capacity_factor / E)
- g: number of tokens per expert per TP rank (i.e., G/TP)
Examples:
x_aD [a, D]
routed_in_etG_D [et*G, D]
x_eGD: [e, G, D]
"""
def __init__(
self,
dim: int,
hidden_dim: int,
ffn_dim_multiplier: float,
multiple_of: int,
moe_args: MoEArgs,
) -> None:
super().__init__()
self.moe_args = moe_args
hidden_dim_denom: float = 1
if moe_args.auto_scale_F:
hidden_dim_denom = moe_args.capacity_factor + 1
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
if moe_args.auto_scale_F:
hidden_dim = int(hidden_dim / hidden_dim_denom)
hidden_dim += -hidden_dim % multiple_of
num_local_experts: int = moe_args.num_experts
dtype: torch.dtype = torch.get_default_dtype()
self.experts = Experts(
num_local_experts,
dim,
hidden_dim,
)
self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "w_in_shared_FD.weight" in state_dict:
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
_, slen, D = x_bsD.shape
x_aD = x_bsD.view(-1, D)
a = x_aD.shape[0]
router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
router_scores = (
torch.full_like(router_scores.transpose(0, 1), float("-inf"))
.scatter_(1, router_indices_aK, router_scores_aK)
.transpose(0, 1)
)
router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
router_scores = torch.sigmoid(router_scores)
routed_in_EG_D: Tensor = torch.gather(
x_aD,
dim=0,
index=router_indices.reshape(-1, 1).expand(-1, D),
)
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
out_aD = self.shared_expert(x_aD)
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
out_aD.scatter_add_(
dim=0,
index=router_indices_EG_D,
src=routed_out_eg_D.view(-1, D),
)
out_aD = reduce_from_model_parallel_region(out_aD)
return out_aD.view(-1, slen, D)
def divide_exact(numerator: int, denominator: int) -> int:
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
return numerator // denominator

View file

@ -0,0 +1,436 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from collections import defaultdict
from typing import Optional, Set, Tuple
import torch
import torchvision.transforms as tv
from PIL import Image, ImageFile
from torchvision.transforms import functional as F
ImageFile.LOAD_TRUNCATED_IMAGES = True
IMAGE_RES = 448
class ResizeNormalizeImageTransform:
def __init__(
self,
size_width=None,
size_height=None,
) -> None:
self._size_width = size_width or IMAGE_RES
self._size_height = size_height or IMAGE_RES
self._mean = (0.5, 0.5, 0.5)
self._std = (0.5, 0.5, 0.5)
self.tv_transform = tv.Compose(
[
tv.Resize((self._size_height, self._size_width)),
tv.ToTensor(),
tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
),
]
)
def __call__(self, image: Image.Image) -> torch.Tensor:
return self.tv_transform(image)
class VariableSizeImageTransform(object):
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
The algorithm will NOT distort the image fit a certain aspect ratio, because
that leads to a significant degradation in image quality.
It can be summarized in 6 steps:
1. Find all possible canvas combinations of max_num_chunks;
2. Find the best canvas to fit the image;
3. Resize without distortion
4. Pad
5. Normalize
6. Chunk
For example, if an input image is of size 300x800, patch_size of 224,
and max_num_chunks = 8, it will find the closest aspect ratio that
is allowed within 8 image chunks, with some restrictions.
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
giving a total of 8 chunks.
If resize_to_max_canvas, the image will be resized (without distortion),
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
where we maintain the original aspect ratio and pad with zeros value for the rest.
This approach minimizes the amount of padding required for any arbitrary resolution.
However, if limit_upscaling_to_patch_size is set to True,
the upscaling will be limited to the patch size. In the example above,
the image would remain 300x800 (no upscaling), and then padded to 448:896.
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
patches are coming from the resizing and chunking.
"""
def __init__(self, size: int = IMAGE_RES) -> None:
self.size = size
self.to_tensor = tv.ToTensor()
self._mean = (0.5, 0.5, 0.5)
self._std = (0.5, 0.5, 0.5)
self.normalize = tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
)
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
n (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
factors_set.add(i)
factors_set.add(n // i)
return factors_set
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
"""
Computes all of the allowed resoltuions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (int): Size of the side of the patch.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 5
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(self.get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for value in asp_dict.values():
for height, width in value:
possible_resolutions.append((height * patch_size, width * patch_size))
return possible_resolutions
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
(134, 200)
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
(450, 338)
"""
original_width, original_height = image_size
target_width, target_height = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_width, new_height
def _pad(self, image: Image.Image, target_size) -> Image.Image:
new_width, new_height = target_size
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
new_im.paste(image)
return new_im
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
num_channels, height, width = image.size()
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
# Permute dimensions to reorder the axes
image = image.permute(1, 3, 0, 2, 4).contiguous()
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
return image
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
If target_size requires upscaling the image, the user can set max_upscaling_size to
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
modifying target_size works as a boundary for the image's largest side.
Args:
resample (str): Resampling method used when resizing images.
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
max_upscaling_size (int): The maximum size to upscale the image to.
If None, there is no limit.
Examples:
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(600, 300) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (2000, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 100) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 2000
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = None
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
"""
image_width, image_height = image.size
image_size = (image_width, image_height)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
if max_upscaling_size is not None:
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
target_size = (new_target_width, new_target_height)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
image = F.resize(
image,
(
max(new_size_without_distortion[1], 1),
max(new_size_without_distortion[0], 1),
),
interpolation=self.resample,
)
return image
def get_best_fit(
self,
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
So we pick the resolution with the smallest smallest area:
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_width, original_height = image_size
# get all possible resolutions heights/widths
target_widths, target_heights = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
def __call__(
self,
image: Image.Image,
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Args:
image (PIL.Image): Image to be resized.
max_num_chunks (int): Maximum number of chunks to split the image into.
normalize_img (bool): Whether to normalize the image.
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
If True, picks the canvas the allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
"""
assert max_num_chunks > 0
assert isinstance(image, Image.Image), type(image)
w, h = image.size
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
possible_resolutions = torch.tensor(possible_resolutions)
best_resolution = self.get_best_fit(
image_size=(w, h),
possible_resolutions=possible_resolutions,
resize_to_max_canvas=resize_to_max_canvas,
)
max_upscaling_size = None if resize_to_max_canvas else self.size
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
image = self._pad(image, best_resolution)
image = self.to_tensor(image)
if normalize_img:
image = self.normalize(image)
ratio_w, ratio_h = (
best_resolution[0] // self.size,
best_resolution[1] // self.size,
)
image = self._split(image, ratio_w, ratio_h) # type: ignore
ar = (ratio_h, ratio_w)
return image, ar

View file

@ -4,20 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from io import BytesIO
from pathlib import Path
from typing import List
from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
from llama_stack.models.llama.prompt_format import (
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
from ..prompt_format import (
Llama4UseCase,
TextCompletionContent,
UseCase,

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from typing import Callable, Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn
from torch.nn import functional as F
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
def swiglu_wrapper_no_reduce(
self,
x: Tensor,
):
from ...quantize_impls import ffn_swiglu
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
def experts_batched_swiglu_wrapper(
self,
x: Tensor, # (e, g, D)
w1: Tensor, # (e, D, F)
w3: Tensor, # (e, D, F)
w2: Tensor, # (e, F, D)
) -> torch.Tensor:
from ...quantize_impls import bmm_nt
middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3) # noqa: N806
return bmm_nt(middle_out_egF, w2)
def convert_to_quantized_model(
model: Transformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
use_rich_progress: bool = True,
) -> Transformer:
from ...quantize_impls import (
Fp8ScaledWeights,
Int4ScaledWeights,
load_fp8,
load_int4,
quantize_fp8,
quantize_int4,
)
rank = get_model_parallel_rank()
def should_quantize_block(block: nn.Module) -> bool:
if not isinstance(block, TransformerBlock):
return False
is_moe = isinstance(block.feed_forward, MoE)
if quantization_mode == QuantizationMode.fp8_mixed:
# skip quantization on first and last layers
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
return is_moe
use_rich_progress = use_rich_progress and rank == 0
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True)
def apply_quantization(key, weight):
scale = int4_scales[key]
return load_int4(
weight,
scale,
output_device=torch.device("cuda"),
)
else:
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
if os.path.isfile(fp8_scales_path):
log_status(f"Rank {rank}: Loading fp8 scales")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
def apply_quantization(key, weight):
scale = fp8_scales[key]
return load_fp8(
weight,
scale,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
else:
log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
def apply_quantization(_, weight):
return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
processed_blocks = 0
try:
if use_rich_progress:
progress.start()
for _, block in model.named_modules():
if not should_quantize_block(block):
continue
update_status(f"Rank {rank} - Layer {block.layer_id}")
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(
f"{prefix}.experts.{key}",
param.transpose(1, 2).contiguous(),
),
)
if quantization_mode == QuantizationMode.int4_mixed:
# Quantize shared experts
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
for key in ("w1", "w3", "w2"):
param = getattr(moe.shared_expert, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
processed_blocks += 1
update_status(message=None, completed=processed_blocks)
update_status(f"Rank {rank} - Moving parameters to CUDA")
param_count = 0
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
parameter.data = parameter.to(device="cuda")
param_count += 1
update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
finally:
if use_rich_progress:
progress.stop()
return model
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
def logging_callbacks(
use_rich_progress: bool,
rank: int,
model: Transformer,
should_quantize_block: Callable[[nn.Module], bool],
):
console = None
if use_rich_progress:
from rich.console import Console
console = Console(highlight=False)
def log_status(message: str) -> None:
if use_rich_progress:
console.print(message)
elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message)
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
progress = None
if use_rich_progress:
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
progress = Progress(
SpinnerColumn(),
BarColumn(complete_style="green", finished_style="bright_green"),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
TextColumn("ETA:"),
TimeRemainingColumn(),
TextColumn("[bold]{task.fields[status]}"),
console=console,
expand=True,
)
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
if use_rich_progress:
if message is not None:
progress.update(task_id, status=message)
if completed is not None:
progress.update(task_id, completed=completed)
elif rank == 0 and completed and completed % 10 == 0:
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
return progress, log_status, update_status

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path
@ -59,8 +56,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
"<|python_start|>",
"<|python_end|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 6
@ -85,8 +80,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
"vision", 1041, 7
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
# 201134, ..., 201143
LLAMA4_REASONING_SPECIAL_TOKENS = [
"<|reasoning_reserved_special_token_0|>",
"<|reasoning_reserved_special_token_1|>",
"<|reasoning_reserved_special_token_2|>",
"<|reasoning_reserved_special_token_3|>",
"<|reasoning_reserved_special_token_4|>",
"<|reasoning_reserved_special_token_5|>",
"<|reasoning_reserved_special_token_6|>",
"<|reasoning_reserved_special_token_7|>",
"<|reasoning_thinking_start|>",
"<|reasoning_thinking_end|>",
]
LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
LLAMA4_SPECIAL_TOKENS = (
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
)
BASIC_SPECIAL_TOKENS = [
"<|begin_of_text|>",
@ -155,6 +165,9 @@ class Tokenizer:
self.eot_id: int = self.special_tokens["<|eot|>"]
self.eom_id: int = self.special_tokens["<|eom|>"]
self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom|>"],

View file

@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from ..args import VisionArgs
from .encoder import VisionEncoder
class PixelShuffle(nn.Module):
def __init__(self, ps_ratio):
super().__init__()
self.ps_ratio = ps_ratio
def forward(self, x):
# x: [B, N, C], N = number of patches
assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
hh = ww = int(math.sqrt(x.shape[1]))
x = x.reshape(x.shape[0], hh, ww, -1)
x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
return pixel_shuffle_patches
def pixel_shuffle_op(input_x, ps_ratio):
n, w, h, c = input_x.size()
input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
input_x = input_x.permute(0, 2, 1, 3).contiguous()
input_x = input_x.view(
n,
int(h * ps_ratio),
int(w * ps_ratio),
int(c / (ps_ratio * ps_ratio)),
)
input_x = input_x.permute(0, 2, 1, 3).contiguous()
return input_x
class SimpleMLP(torch.nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
bias: bool = True,
dropout: float = 0.0,
act_layer: Callable = nn.GELU,
):
super().__init__()
# layers
self.c_fc = ColumnParallelLinear(
dim,
hidden_dim,
bias=bias,
gather_output=False,
)
self.c_proj = RowParallelLinear(
hidden_dim,
hidden_dim,
bias=bias,
input_is_parallel=True,
)
self.non_linearity = act_layer()
self.dropout = dropout
def forward(self, x):
hidden = self.c_fc(x)
hidden = self.non_linearity(hidden)
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
return self.non_linearity(self.c_proj(hidden))
class PixelShuffleMLP(torch.nn.Module):
def __init__(
self,
ps_ratio: float,
input_dim: int,
output_dim: int = 4096,
add_fc: bool = False,
):
super().__init__()
self.pixel_shuffle = PixelShuffle(ps_ratio)
self.mlp = SimpleMLP(
int(input_dim // (ps_ratio**2)),
output_dim,
bias=False,
dropout=0.0,
act_layer=nn.GELU,
)
self.fc = nn.Identity()
if add_fc:
self.fc = ColumnParallelLinear(
output_dim,
output_dim,
bias=False,
)
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = self.pixel_shuffle(encoded_patches)
return self.fc(self.mlp(encoded_patches))
class VisionEmbeddings(torch.nn.Module):
def __init__(self, args: VisionArgs):
super().__init__()
self.args = args
image_size = args.image_size
patch_size = args.patch_size
self.vision_encoder = VisionEncoder(
image_size=(image_size.height, image_size.width),
patch_size=(patch_size.height, patch_size.width),
dim=args.dim,
layers=args.n_layers,
heads=args.n_heads,
mlp_ratio=args.mlp_ratio,
)
self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
self.vision_adapter = PixelShuffleMLP(
ps_ratio=args.pixel_shuffle_ratio,
input_dim=args.dim,
output_dim=args.output_dim,
)
self.output_dim = args.output_dim
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
return_state_dict: bool = False,
) -> None:
original_sd = self.state_dict()
for k in state_dict:
if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
def _get_empty_sequence(self, h):
return torch.zeros(
h.shape[0],
h.shape[1],
self.output_dim,
device=h.device,
dtype=h.dtype,
)
# x_images is batched; each batch sample contains a list of images. so this is List[List[torch.Tensor]]
# each image is a tensor of shape [num_tiles, C, H, W]
def forward(
self,
image_batch: List[List[torch.Tensor]],
image_mask: torch.Tensor,
h_ref: torch.Tensor,
) -> torch.Tensor:
images_flattened = [image for sample in image_batch for image in sample]
images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
embedding = self.vision_encoder(images_flattened)
projected_embedding = self.vision_adapter(embedding)
h_image = self._get_empty_sequence(h_ref)
return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
# If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
# `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
# `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
assert not torch.isnan(encoded_patches_proj).any()
assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
)
encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
for index in range(h_image.size(0)):
encoded_patches_per_sample = encoded_patches_list[index]
sample_image_mask = image_mask[index]
if encoded_patches_per_sample.numel() == 0:
continue
encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
-1, encoded_patches_per_sample.size(-1)
)
n_tokens_to_fill = sample_image_mask.sum()
assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
h_image[index].masked_scatter_(
sample_image_mask.expand(-1, h_image.size(-1)),
encoded_patches_per_sample[:n_tokens_to_fill],
)
return h_image

View file

@ -0,0 +1,411 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import einsum
from ..args import ModelArgs
from ..model import Attention
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x
class ColumnParallelConv2dPatch(torch.nn.Module):
"""Conv2D Patching layer with model parallelism.
Column parallel over unfolded input.
Arguments:
in_channels: Input channels.
out_channels: Output channels.
kernel_size: Size of convolution kernel.
stride (default 1): Stride for convolution.
bias (default False): Use bias in Conv2d.
Input: (bsz, in_channels, height, width)
Output: (bsz, num_tokens, out_channels)
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: Optional[bool] = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
self._linear = ColumnParallelLinear(
in_channels * kernel_size[0] * kernel_size[1],
out_channels,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._unfold(x)
x = x.permute(0, 2, 1)
x = self._linear(x)
return x
class _FeedForward(torch.nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
dropout: float,
act_layer: Callable = nn.GELU,
):
super().__init__()
# layers
self.c_fc = ColumnParallelLinear(
dim,
hidden_dim,
bias=True,
gather_output=False,
init_method=lambda x: x,
)
self.c_proj = RowParallelLinear(
hidden_dim,
dim,
bias=True,
input_is_parallel=True,
init_method=lambda x: x,
)
self.non_linearity = act_layer()
self.dropout = dropout
def forward(self, x):
hidden = self.c_fc(x)
hidden = self.non_linearity(hidden)
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
return self.c_proj(hidden)
class _TransformerBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
gated: bool = False,
):
super().__init__()
assert d_model % n_head == 0
self.n_heads = n_head
self.head_dim = d_model // self.n_heads
attn_args = ModelArgs(
dim=d_model,
head_dim=self.head_dim,
n_heads=self.n_heads,
n_kv_heads=self.n_heads,
)
self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
self.ln_1 = LayerNorm(d_model)
self.mlp = _FeedForward(
dim=d_model,
hidden_dim=int(mlp_ratio * d_model),
dropout=0.0,
act_layer=act_layer,
)
self.ln_2 = LayerNorm(d_model)
self.gated = gated
if gated:
self.gate_attn = nn.Parameter(torch.zeros(1))
self.gate_ffn = nn.Parameter(torch.zeros(1))
def attention(
self,
x: torch.Tensor,
freq_cis: Optional[torch.Tensor] = None,
):
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
freq_cis: Optional[torch.Tensor] = None,
):
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
x = x + _gate_ffn * self.mlp(self.ln_2(x))
return x
class _Transformer(nn.Module):
def __init__(
self,
dim: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
gated: bool = False,
):
super().__init__()
self.resblocks = nn.ModuleList(
[
_TransformerBlock(
d_model=dim,
n_head=heads,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
gated=gated,
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
out = []
for idx, r in enumerate(self.resblocks):
if return_intermediate is not None and idx in return_intermediate:
out.append(x)
x = r(x, mask=mask, freq_cis=freq_cis)
if return_intermediate is not None:
return x, torch.stack(out, dim=-1)
return x
class PackingIndex:
Z = 0 # Z (time) coordinate of the token in the original sample
Y = 1 # Y (height) coordinate of the token in the original sample
X = 2 # X (width) coordinate of the token in the original sample
TIME = 3 # Total number of time units (frames) in the original sample
HEIGHT = 4 # Height of the original sample
WIDTH = 5 # Width of the original sample
# USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
IDX = 6 # Full index of the token in the original sample (x + y * w + z * w * h)
BATCH_IDX = 7 # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
# Total size of the enum, remember to update this!
NUM_METADATA = 8
# Note: For padding tokens IDX = -1
# For cls tokens, IDX = -2
ID_CLS_TOKEN = -2
ID_PAD_TOKEN = -1
class VisionEncoder(nn.Module):
def __init__(
self,
image_size: Tuple[int, int],
patch_size: Tuple[int, int],
dim: int,
layers: int,
heads: int,
mlp_ratio: float,
in_channels: int = 3,
):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.grid_size = (
self.image_size[0] // self.patch_size[0],
self.image_size[1] // self.patch_size[1],
)
self.conv1 = ColumnParallelConv2dPatch(
in_channels=in_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = dim**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(dim))
self.positional_embedding_vlm = nn.Parameter(
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
)
self.ln_pre = LayerNorm(dim)
self.ln_post = LayerNorm(dim)
self.transformer = _Transformer(
dim,
layers,
heads,
mlp_ratio,
act_layer=nn.GELU,
)
# NOTE: hack for the fixed res
image_h, image_w = self.image_size
patch_h, patch_w = self.patch_size
idx_h, idx_w = image_h // patch_h, image_w // patch_w
img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
img_idx = img_idx.reshape(idx_h * idx_w, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
packed_img_idx = torch.empty(
img_idx.shape[0],
img_idx.shape[1],
PackingIndex.NUM_METADATA - 1,
dtype=torch.int32,
)
packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
packed_img_idx[:, :, PackingIndex.IDX] = img_idx
packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
self.packed_img_idx = packed_img_idx # for positional embedding load hook
# compute rope freqs
rope_freq = self.get_rope_freqs(dim // heads // 2)
freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
# disable RoPE for padding and cls tokens
freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
# compute complex freqs
self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
# xlf automatically broadcasts
self.freq_cis = self.freq_cis.squeeze(0)
self.n_heads = heads // fs_init.get_model_parallel_world_size()
self._register_load_state_dict_pre_hook(self.load_hook)
def get_rope_freqs(self, dim, theta=10000):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
return freqs
@torch.amp.autocast("cuda", enabled=False)
def compute_rope_freqs(self, freqs, t):
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
freqs = freqs.repeat_interleave(2, dim=-1)
return freqs
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
raise ValueError(
f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
)
batch_size, token_per_image, _ = self.packed_img_idx.shape
# Input points for idx are [x, y, w, h]
idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
total_windows, window_size, _ = idx.shape
# Grid values are [-1, 1] and coords are w, h
grid = (
(idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
)[None, ...]
# In this mode, cls token has no position embedding
if orig_pos_embed is not None:
posemb = (
orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
)
posemb = posemb.to(device=grid.device, dtype=grid.dtype)
sample = F.grid_sample(
posemb, grid, padding_mode="zeros"
) # padding tokens / class token will get zero for posemb
sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
sample = torch.where(
idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
sample,
)
new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
if return_state_dict:
return state_dict
def apply_class_embedding(self, x):
x = torch.cat(
[
x,
self.class_embedding.to(x.dtype)
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
return x
def forward(self, images: torch.Tensor) -> torch.Tensor:
# NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
if images.ndim == 5:
num_concurrent_media = 1
bsz, num_chunks, nch, h, w = images.shape
else:
bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
# patch embedding
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
x = self.conv1(x) # shape = [*, width, grid ** 2]
_, ntok, dim = x.shape
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
# apply cls token
x = self.apply_class_embedding(x)
ntok += 1
# apply position embeddings
if self.positional_embedding_vlm is not None:
x = x + self.positional_embedding_vlm.to(x.dtype)
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
x = self.ln_pre(x)
x = x.view(bsz * num_concurrent_media, -1, dim)
freq_cis = self.freq_cis.to(images.device)
tf_output = self.transformer(
x,
freq_cis=freq_cis,
)
int_x = None
if isinstance(tf_output, tuple):
x, int_x = tf_output
else:
x = tf_output
x = self.ln_post(x)
# remove cls token output
x = x[:, :-1, :]
# add and output x + int_x features
if int_x is not None:
int_x = int_x[:, :-1, :, :]
int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
x = torch.cat([x, int_x], dim=-1)
return x

View file

@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from .llama3.interface import LLama31Interface
from .llama3.template_data import (
@ -76,21 +73,22 @@ class UseCase(BaseModel):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
input_tokens, output_tokens = generator.text_completion_raw(
dialog.content,
temperature=0.1,
top_p=0.95,
max_gen_len=64,
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
temperature=0.0,
top_p=0.95,
max_gen_len=self.max_gen_len,
batch = [dialog]
method = (
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
)
input_tokens = []
output_tokens = []
for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
result = token_results[0]
if result.source == "input":
input_tokens.append(result.token)
else:
output_tokens.append(result.token)
if result.finished:
break
text += "##### Input Prompt Format\n"
# FIXME: This is added to undo the hack in chat_formatter where
@ -126,27 +124,27 @@ class Llama4UseCase(UseCase):
text = ""
tokenizer = Tokenizer.get_instance()
temperature = 0.0
for dialog in self.dialogs:
if isinstance(dialog, str):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
# TODO pass the raw input and do the encoding in the text completion function
input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
llm_input = LLMInput(tokens=input_tokens)
output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
llm_input, temperature=temperature, max_gen_len=self.max_gen_len
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
temperature=temperature,
max_gen_len=self.max_gen_len,
batch = [dialog]
method = (
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
)
input_tokens = []
output_tokens = []
for token_results in method(batch, echo=True, temperature=0.0):
result = token_results[0]
if result.source == "input":
input_tokens.append(result.token)
else:
output_tokens.append(result.token)
if result.finished:
break
text += "##### Input Prompt Format\n"
text += _code_block(tokenizer.decode(input_tokens))

View file

@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# type: ignore
import collections
import logging
from typing import Optional, Tuple, Type, Union
log = logging.getLogger(__name__)
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
except ImportError:
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
raise
import torch
from torch import Tensor, nn
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
Fp8ScaledWeights,
collections.namedtuple(
"Fp8RowwiseWeights",
["weight", "scale", "shape", "activation_scale_ub"],
),
):
pass
class Int4ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Int4Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Int4Weights(
Int4ScaledWeights,
collections.namedtuple(
"Int4Weights",
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
),
):
pass
def int4_row_quantize(
x: torch.Tensor,
group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float)
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
# Recenter output and move to int8.
out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
# Cutlass expects column major layout for scale and zero point,
# so we transpose here and make them contiguous.
scales = scales.view(x.shape[0], -1).t().contiguous()
zeros = zeros.view(x.shape[0], -1).t().contiguous()
return out, scales, zeros
def pack_int4(x: torch.Tensor) -> torch.Tensor:
# Given int8 x, pack adjacent int4 values into a single int8.
low_x = x[:, ::2]
high_x = x[:, 1::2]
# High bits need to left shift, this also masks off extra bits.
high_x = torch.bitwise_left_shift(high_x, 4)
# Low bits need to have sign bits removed.
low_x = torch.bitwise_and(low_x, 0xF)
# Recombine into a single value with bitwise or.
return torch.bitwise_or(low_x, high_x).contiguous()
def bmm_nt(
x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights],
num_tokens: Optional[Tensor] = None,
) -> Tensor:
if isinstance(w, Fp8ScaledWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
elif isinstance(w, Int4ScaledWeights):
return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
else:
raise ValueError("Unsupported quantization type")
def ffn_swiglu(
x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights],
w3: Union[Fp8RowwiseWeights, Int4Weights],
w2: Union[Fp8RowwiseWeights, Int4Weights],
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
):
return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
assert isinstance(w1, Tensor)
assert isinstance(w3, Tensor)
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
del x1, x2
assert isinstance(w2, Tensor)
return (z @ w2.T).view(B, T, D)
@torch.inference_mode()
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
weight=wq,
scale=w_scale,
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def quantize_int4(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Quantize [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
if w.ndim >= 3:
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
scale = torch.stack(scale, dim=0)
zero_point = torch.stack(zero_point, dim=0)
else:
wq, scale, zero_point = int4_row_quantize(w)
wq = pack_int4(wq)
del w
return Int4Weights(
weight=wq.to(output_device),
scale=scale.to(output_device),
zero_point=zero_point.to(output_device),
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
Args:
w (Tensor): [n, k] input FP8.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
return Fp8RowwiseWeights(
weight=w.to(torch.float8_e4m3fn).to(device=output_device),
scale=w_scale.to(device=output_device),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_int4(
w: Tensor,
scale: Tensor,
zero_point: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Load INT4 [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input INT4.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
return Int4Weights(
weight=w.to(torch.int8).to(device=output_device),
scale=scale.to(device=output_device),
zero_point=zero_point.to(device=output_device),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
def fc_dynamic(
x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights],
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dyanmic row-wise scaling
"""
if isinstance(w, Int4Weights):
y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
else:
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
del xq
return y
def ffn_swiglu_dynamic(
x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights],
w3: Union[Fp8RowwiseWeights, Int4Weights],
w2: Union[Fp8RowwiseWeights, Int4Weights],
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
assert x.dim() == 3 or x.dim() == 2
if x.dim() == 3:
(B, T, D) = x.shape # noqa: N806
else:
(T, D) = x.shape # noqa: N806
B = 1 # noqa: N806
HD_L = w1.shape[0] # noqa: N806
assert HD_L == w3.shape[0]
x1 = fc_dynamic(
x.view(B * T, D),
w1,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
x2 = fc_dynamic(
x.view(B * T, D),
w3,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
z = torch.nn.functional.silu(x1) * x2
del x1, x2
z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
if x.dim() == 3:
return z_.view(B, T, D)
else:
return z_

View file

@ -4,24 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
from .datatypes import (
from .sku_types import (
CheckpointQuantizationFormat,
CoreModelId,
Model,
ModelFamily,
SamplingParams,
TopPSamplingStrategy,
)
LLAMA2_VOCAB_SIZE = 32000
@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
)
def recommended_sampling_params() -> SamplingParams:
return SamplingParams(
strategy=TopPSamplingStrategy(
temperature=1.0,
top_p=0.9,
)
)
def llama2_family() -> List[Model]:
return [
*llama2_base_models(),
@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b,
description="Llama 2 7b model",
huggingface_repo="meta-llama/Llama-2-7b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b,
description="Llama 2 13b model",
huggingface_repo="meta-llama/Llama-2-13b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b,
description="Llama 2 70b model",
huggingface_repo="meta-llama/Llama-2-70b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b,
description="Llama 3 70b model",
huggingface_repo="meta-llama/Llama-3-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b,
description="Llama 3.1 8b model",
huggingface_repo="meta-llama/Llama-3.1-8B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b,
description="Llama 3.1 70b model",
huggingface_repo="meta-llama/Llama-3.1-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
description="Llama 3.1 405b model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b,
description="Llama 3.2 1b model",
huggingface_repo="meta-llama/Llama-3.2-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,
@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b,
description="Llama 3.2 3b model",
huggingface_repo="meta-llama/Llama-3.2-3B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 3072,
"n_layers": 28,
@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision,
description="Llama 3.2 11b vision model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision,
description="Llama 3.2 90b vision model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b_chat,
description="Llama 2 7b chat model",
huggingface_repo="meta-llama/Llama-2-7b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b_chat,
description="Llama 2 13b chat model",
huggingface_repo="meta-llama/Llama-2-13b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b_chat,
description="Llama 2 70b chat model",
huggingface_repo="meta-llama/Llama-2-70b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_8b_instruct,
description="Llama 3 8b instruct model",
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b_instruct,
description="Llama 3 70b instruct model",
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b_instruct,
description="Llama 3.1 8b instruct model",
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b_instruct,
description="Llama 3.1 70b instruct model",
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b instruct model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
description="Llama 3.1 405b instruct model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b_instruct,
description="Llama 3.2 1b instruct model",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_1b(),
pth_file_count=1,
),
@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b_instruct,
description="Llama 3.2 3b instruct model",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_3b(),
pth_file_count=1,
),
@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
description="Llama 3.2 11b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
description="Llama 3.2 90b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_3_70b_instruct,
description="Llama 3.3 70b instruct",
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_11b_vision,
description="Llama Guard v3 11b vision system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
description="Llama Guard v3 1b 'int4' quantized system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
quantization_format=CheckpointQuantizationFormat.int4,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 12,
@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_1b,
description="Llama Guard v3 1b system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,

View file

@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.core_model_id.value
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")