forked from phoenix-oss/llama-stack-mirror
refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do? Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference specific tidbits into llama-models code even by accident. Also, kills the meta-reference-quantized-gpu distro and rolls quantization deps into meta-reference-gpu. ## Test Plan ``` LLAMA_MODELS_DEBUG=1 \ with-proxy llama stack run meta-reference-gpu \ --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \ --env INFERENCE_CHECKPOINT_DIR=<DIR> \ --env MODEL_PARALLEL_SIZE=4 \ --env QUANTIZATION_TYPE=fp8_mixed ``` Start a server with and without quantization. Point integration tests to it using: ``` pytest -s -v tests/integration/inference/test_text_inference.py \ --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct ```
This commit is contained in:
parent
c52ccc4bbd
commit
530d4bdfe1
85 changed files with 1267 additions and 1683 deletions
95
llama_stack/models/llama/llama4/args.py
Normal file
95
llama_stack/models/llama/llama4/args.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class QuantizationScheme(Enum):
|
||||
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
class QuantizationArgs(BaseModel):
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
spinquant: bool = False
|
||||
|
||||
|
||||
class LoRAArgs(BaseModel):
|
||||
rank: int
|
||||
scale: float
|
||||
|
||||
|
||||
class MoEArgs(BaseModel):
|
||||
num_experts: int = -1
|
||||
capacity_factor: float = 1.0 # capacity factor determines how many tokens each expert can choose
|
||||
auto_scale_F: bool = ( # noqa: N815
|
||||
True # if true, rescales hidden_dim such that number of activated params is same as equivalent dense layer
|
||||
)
|
||||
top_k: int = 1
|
||||
interleave_moe_layer_step: int = 1
|
||||
|
||||
|
||||
class Size(BaseModel):
|
||||
height: int
|
||||
width: int
|
||||
|
||||
|
||||
class VisionArgs(BaseModel):
|
||||
image_size: Size
|
||||
patch_size: Size
|
||||
|
||||
# parameters for the encoder transformer
|
||||
dim: int
|
||||
n_layers: int
|
||||
n_heads: int
|
||||
mlp_ratio: float
|
||||
output_dim: int
|
||||
|
||||
pixel_shuffle_ratio: float
|
||||
|
||||
|
||||
class ModelArgs(BaseModel):
|
||||
dim: int = -1
|
||||
n_layers: int = -1
|
||||
n_heads: int = -1
|
||||
n_kv_heads: Optional[int] = None
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
ffn_exp: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
attention_chunk_size: Optional[int] = None
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
|
||||
use_qk_norm: bool = False
|
||||
# Set to True to enable inference-time temperature tuning (useful for very long context)
|
||||
attn_temperature_tuning: bool = False
|
||||
floor_scale: float = 8192.0
|
||||
attn_scale: float = 0.1
|
||||
|
||||
vision_args: Optional[VisionArgs] = None
|
||||
moe_args: Optional[MoEArgs] = None
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate(self) -> "ModelArgs":
|
||||
assert self.n_kv_heads <= self.n_heads, f"n_kv_heads ({self.n_kv_heads}) must be <= n_heads ({self.n_heads})"
|
||||
assert self.n_heads % self.n_kv_heads == 0, (
|
||||
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
|
||||
return self
|
|
@ -13,7 +13,7 @@ import torch
|
|||
from PIL import Image as PIL_Image
|
||||
|
||||
# TODO: either fork these or move them to the common package
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
|
@ -24,16 +24,10 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
|
||||
LLMInput,
|
||||
)
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
|
||||
ResizeNormalizeImageTransform,
|
||||
VariableSizeImageTransform,
|
||||
)
|
||||
|
||||
from ..llama3.tool_utils import ToolUtils
|
||||
from .args import VisionArgs
|
||||
from .datatypes import LLMInput
|
||||
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
|
@ -54,7 +48,7 @@ class TransformedImage:
|
|||
aspect_ratio: Tuple[int, int]
|
||||
|
||||
|
||||
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
if image.mode == "RGBA":
|
||||
image.load() # for png.split()
|
||||
new_img = PIL_Image.new("RGB", image.size, bg)
|
||||
|
@ -171,7 +165,7 @@ class ChatFormat:
|
|||
|
||||
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||
image = PIL_Image.open(bytes_io)
|
||||
image = convert_rgba_to_rgb(image)
|
||||
image = convert_image_to_rgb(image)
|
||||
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
|
||||
|
||||
if image_tiles.shape[0] > 1:
|
||||
|
|
57
llama_stack/models/llama/llama4/datatypes.py
Normal file
57
llama_stack/models/llama/llama4/datatypes.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskedEmbedding:
|
||||
embedding: torch.Tensor
|
||||
mask: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInput:
|
||||
"""
|
||||
This is the input to the LLM from the "user" -- the user in this case views the
|
||||
Llama4 model holistically and does not care or know about its inner workings (e.g.,
|
||||
whether it has an encoder or if it is early fusion or not.)
|
||||
|
||||
This is distinct from the "TransformerInput" class which is really the Llama4
|
||||
backbone operating on early fused modalities and producing text output
|
||||
"""
|
||||
|
||||
tokens: torch.Tensor
|
||||
|
||||
# images are already pre-processed (resized, tiled, etc.)
|
||||
images: Optional[List[torch.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerInput:
|
||||
"""
|
||||
This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
|
||||
are expected to be "embedded" via encoders sitting before this layer in the model.
|
||||
"""
|
||||
|
||||
tokens: torch.Tensor
|
||||
|
||||
# tokens_position defines the position of the tokens in each batch,
|
||||
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
|
||||
# - when it is an int, the start position are the same for all batches
|
||||
tokens_position: Union[torch.Tensor, int]
|
||||
image_embedding: Optional[MaskedEmbedding] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMOutput:
|
||||
logits: torch.Tensor
|
||||
|
||||
|
||||
TransformerOutput = LLMOutput
|
58
llama_stack/models/llama/llama4/ffn.py
Normal file
58
llama_stack/models/llama/llama4/ffn.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
do_reduce: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.do_reduce = do_reduce
|
||||
|
||||
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "mlp.fc1_weight" in state_dict:
|
||||
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
|
||||
state_dict[prefix + "w1.weight"] = w1
|
||||
state_dict[prefix + "w3.weight"] = w3
|
||||
state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")
|
||||
|
||||
def forward(self, x):
|
||||
x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
|
||||
out = F.linear(x, self.w2.weight)
|
||||
if self.do_reduce:
|
||||
return reduce_from_model_parallel_region(out)
|
||||
return out
|
313
llama_stack/models/llama/llama4/generation.py
Normal file
313
llama_stack/models/llama/llama4/generation.py
Normal file
|
@ -0,0 +1,313 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import codecs
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, RawContent, RawMessage
|
||||
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
|
||||
from .model import Transformer
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
|
||||
|
||||
|
||||
class Llama4:
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
seed: int = 1,
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
if world_size is None:
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
**params,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
# TODO: params.json should always have correct vocab_size
|
||||
if model_args.vocab_size == -1:
|
||||
model_args.vocab_size = tokenizer.n_words
|
||||
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
|
||||
print("Model args:\n", model_args.model_dump_json(indent=2))
|
||||
|
||||
state_dict = maybe_reshard_state_dict(
|
||||
ckpt_paths,
|
||||
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
|
||||
moe_num_experts=model_args.moe_args.num_experts,
|
||||
)
|
||||
print("Loaded checkpoint")
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama4(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer, vision_args=args.vision_args)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
|
||||
max_gen_len = self.model.args.max_seq_len - 1
|
||||
|
||||
params = self.model.args
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
cprint("Input to model:\n", "yellow")
|
||||
for inp in llm_inputs:
|
||||
cprint(self.tokenizer.decode(inp.tokens), "grey")
|
||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||
|
||||
bsz = len(llm_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i in range(max_prompt_len):
|
||||
results = []
|
||||
for j, t in enumerate(tokens[:, i]):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="input",
|
||||
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
|
||||
batch_idx=j,
|
||||
finished=False,
|
||||
ignore_token=t.item() == pad_id,
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
image_embedding = None
|
||||
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
|
||||
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
|
||||
image_mask = image_mask.unsqueeze(-1)
|
||||
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
|
||||
|
||||
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
|
||||
image_embedding = MaskedEmbedding(
|
||||
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
|
||||
mask=image_mask,
|
||||
)
|
||||
|
||||
xformer_input = TransformerInput(
|
||||
tokens=tokens[:, prev_pos:cur_pos],
|
||||
tokens_position=prev_pos,
|
||||
image_embedding=image_embedding,
|
||||
)
|
||||
xformer_output = self.model.forward(xformer_input)
|
||||
logits = xformer_output.logits
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=target,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
|
||||
results = []
|
||||
for idx, t in enumerate(next_token):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx],
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
442
llama_stack/models/llama/llama4/model.py
Normal file
442
llama_stack/models/llama/llama4/model.py
Normal file
|
@ -0,0 +1,442 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .args import ModelArgs
|
||||
from .datatypes import TransformerInput, TransformerOutput
|
||||
from .ffn import FeedForward
|
||||
from .moe import MoE
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
class L2Norm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor):
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
new_freqs = []
|
||||
for freq in freqs:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < high_freq_wavelen:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / scale_factor)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
|
||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
# TODO: this module needs to be moved into a separate file since it can be used by
|
||||
# the vision encoder as well.
|
||||
def __init__(
|
||||
self,
|
||||
args: ModelArgs,
|
||||
use_qk_norm: bool,
|
||||
use_rope: bool,
|
||||
add_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_rope = use_rope
|
||||
self.use_qk_norm = use_qk_norm
|
||||
# For attention temperature tuning
|
||||
self.attn_temperature_tuning = args.attn_temperature_tuning
|
||||
self.floor_scale = args.floor_scale
|
||||
self.attn_scale = args.attn_scale
|
||||
|
||||
self.n_heads = args.n_heads
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=add_bias,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.qk_norm = None
|
||||
if self.use_qk_norm:
|
||||
self.qk_norm = L2Norm(args.norm_eps)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "wqkv.weight" in state_dict:
|
||||
wqkv = state_dict.pop(prefix + "wqkv.weight")
|
||||
d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
|
||||
if r != 0:
|
||||
raise ValueError(
|
||||
f"shape={tuple(wqkv.shape)} is not divisible by "
|
||||
f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
|
||||
state_dict[prefix + "wq.weight"] = wq
|
||||
state_dict[prefix + "wk.weight"] = wk
|
||||
state_dict[prefix + "wv.weight"] = wv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
if self.use_rope:
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
if self.use_qk_norm:
|
||||
xq = self.qk_norm(xq)
|
||||
xk = self.qk_norm(xk)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||
# the inference-time temperature tuning function is customized to not affect short context
|
||||
# while working at very long context
|
||||
if self.attn_temperature_tuning and not self.use_rope:
|
||||
seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
|
||||
attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
|
||||
|
||||
# reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
|
||||
attn_scales = attn_scales.view(1, seqlen, 1, 1)
|
||||
xq = xq * attn_scales
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
xk = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
xv = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
|
||||
|
||||
xk = xk.repeat_interleave(self.n_rep, dim=1)
|
||||
xv = xv.repeat_interleave(self.n_rep, dim=1)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
output = self.wo(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
|
||||
|
||||
self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
|
||||
|
||||
use_rope = not self.is_nope_layer
|
||||
use_qk_norm = args.use_qk_norm and not self.is_nope_layer
|
||||
|
||||
self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
|
||||
|
||||
if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
|
||||
self.feed_forward = MoE(
|
||||
dim=args.dim,
|
||||
hidden_dim=int(args.ffn_exp * args.dim),
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
multiple_of=args.multiple_of,
|
||||
moe_args=args.moe_args,
|
||||
)
|
||||
else:
|
||||
hidden_dim = int(4 * args.dim)
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
if args.ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=hidden_dim,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
|
||||
|
||||
if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
|
||||
elif prefix + "feed_forward.norm.weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
|
||||
|
||||
for k in (
|
||||
"feed_forward.experts.mlp",
|
||||
"feed_forward.mlp_shared",
|
||||
"attention.wo",
|
||||
"attention.wqkv",
|
||||
):
|
||||
if prefix + k + "._extra_state" in state_dict:
|
||||
state_dict.pop(prefix + k + "._extra_state")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
global_attn_mask: Optional[torch.Tensor],
|
||||
local_attn_mask: Optional[torch.Tensor],
|
||||
):
|
||||
# The iRoPE architecture uses global attention mask for NoPE layers or
|
||||
# if chunked local attention is not used
|
||||
if self.is_nope_layer or local_attn_mask is None:
|
||||
mask = global_attn_mask
|
||||
else:
|
||||
mask = local_attn_mask
|
||||
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, args: ModelArgs, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, args))
|
||||
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
args.dim // args.n_heads,
|
||||
args.max_seq_len * 2,
|
||||
args.rope_theta,
|
||||
args.use_scaled_rope,
|
||||
)
|
||||
vision_args = self.args.vision_args
|
||||
if vision_args:
|
||||
# circular import otherwise until we refactor out Attention
|
||||
from .vision.embedding import VisionEmbeddings
|
||||
|
||||
self.vision_embeddings = VisionEmbeddings(vision_args)
|
||||
self.vision_projection = ColumnParallelLinear(
|
||||
vision_args.output_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "rope.freqs" in state_dict:
|
||||
state_dict.pop(prefix + "rope.freqs")
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, model_input: TransformerInput) -> TransformerOutput:
|
||||
tokens = model_input.tokens
|
||||
start_pos = model_input.tokens_position
|
||||
assert isinstance(start_pos, int), (
|
||||
"This implementation does not support different start positions per batch item"
|
||||
)
|
||||
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
if image_embedding := model_input.image_embedding:
|
||||
h_image = self.vision_projection(image_embedding.embedding)
|
||||
h = h * ~image_embedding.mask + h_image * image_embedding.mask
|
||||
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
global_attn_mask, local_attn_mask = None, None
|
||||
if seqlen > 1:
|
||||
global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if global_attn_mask.device.type == torch.device("mps").type:
|
||||
global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
|
||||
|
||||
if chunk_size := self.args.attention_chunk_size:
|
||||
local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
|
||||
return TransformerOutput(logits=output)
|
||||
|
||||
|
||||
# tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
|
||||
# in the iRoPE architecture
|
||||
def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
|
||||
block_pos = torch.abs(
|
||||
(torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
|
||||
- (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
|
||||
)
|
||||
token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
|
||||
mask = (block_pos == 0) & (token_pos <= 0)
|
||||
return mask.to(device)
|
214
llama_stack/models/llama/llama4/moe.py
Normal file
214
llama_stack/models/llama/llama4/moe.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# ruff: noqa: N806
|
||||
# pyre-strict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .args import MoEArgs
|
||||
from .ffn import FeedForward
|
||||
|
||||
|
||||
class Experts(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_local_experts: int,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
self.num_local_experts = num_local_experts
|
||||
self.dim = dim
|
||||
divide_factor = fs_init.get_model_parallel_world_size()
|
||||
|
||||
self.w1: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w2: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w3: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
self.prefix = prefix
|
||||
if prefix + "moe_w_in_eD_F" in state_dict:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
|
||||
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
|
||||
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
routed_in_egD: torch.Tensor, # noqa: N803
|
||||
) -> torch.Tensor:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
|
||||
x_egD = routed_in_egD.view(e, -1, D)
|
||||
|
||||
out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
|
||||
out_egD = out_egD.view(-1, D)
|
||||
|
||||
return out_egD
|
||||
|
||||
def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
|
||||
return torch.bmm(middle_out_egF, w2)
|
||||
|
||||
|
||||
class MoE(torch.nn.Module):
|
||||
"""
|
||||
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
|
||||
Several commonly used annotations include:
|
||||
- a: bsz*slen
|
||||
- E: number of experts
|
||||
- e: number of local experts per ep (n_experts/ep)
|
||||
- D: hidden dimension
|
||||
- d: D/tp
|
||||
- F: model dimension
|
||||
- G: number of tokens per expert (a * capacity_factor / E)
|
||||
- g: number of tokens per expert per TP rank (i.e., G/TP)
|
||||
|
||||
Examples:
|
||||
x_aD [a, D]
|
||||
routed_in_etG_D [et*G, D]
|
||||
x_eGD: [e, G, D]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
ffn_dim_multiplier: float,
|
||||
multiple_of: int,
|
||||
moe_args: MoEArgs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.moe_args = moe_args
|
||||
|
||||
hidden_dim_denom: float = 1
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim_denom = moe_args.capacity_factor + 1
|
||||
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
|
||||
# custom dim factor multiplier
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim = int(hidden_dim / hidden_dim_denom)
|
||||
|
||||
hidden_dim += -hidden_dim % multiple_of
|
||||
|
||||
num_local_experts: int = moe_args.num_experts
|
||||
dtype: torch.dtype = torch.get_default_dtype()
|
||||
self.experts = Experts(
|
||||
num_local_experts,
|
||||
dim,
|
||||
hidden_dim,
|
||||
)
|
||||
|
||||
self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
|
||||
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "w_in_shared_FD.weight" in state_dict:
|
||||
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
|
||||
|
||||
def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
|
||||
_, slen, D = x_bsD.shape
|
||||
x_aD = x_bsD.view(-1, D)
|
||||
|
||||
a = x_aD.shape[0]
|
||||
|
||||
router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
|
||||
|
||||
router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
|
||||
router_scores = (
|
||||
torch.full_like(router_scores.transpose(0, 1), float("-inf"))
|
||||
.scatter_(1, router_indices_aK, router_scores_aK)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
|
||||
|
||||
router_scores = torch.sigmoid(router_scores)
|
||||
|
||||
routed_in_EG_D: Tensor = torch.gather(
|
||||
x_aD,
|
||||
dim=0,
|
||||
index=router_indices.reshape(-1, 1).expand(-1, D),
|
||||
)
|
||||
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
|
||||
|
||||
out_aD = self.shared_expert(x_aD)
|
||||
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
|
||||
|
||||
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
|
||||
out_aD.scatter_add_(
|
||||
dim=0,
|
||||
index=router_indices_EG_D,
|
||||
src=routed_out_eg_D.view(-1, D),
|
||||
)
|
||||
out_aD = reduce_from_model_parallel_region(out_aD)
|
||||
return out_aD.view(-1, slen, D)
|
||||
|
||||
|
||||
def divide_exact(numerator: int, denominator: int) -> int:
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
||||
return numerator // denominator
|
436
llama_stack/models/llama/llama4/preprocess.py
Normal file
436
llama_stack/models/llama/llama4/preprocess.py
Normal file
|
@ -0,0 +1,436 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image, ImageFile
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
IMAGE_RES = 448
|
||||
|
||||
|
||||
class ResizeNormalizeImageTransform:
|
||||
def __init__(
|
||||
self,
|
||||
size_width=None,
|
||||
size_height=None,
|
||||
) -> None:
|
||||
self._size_width = size_width or IMAGE_RES
|
||||
self._size_height = size_height or IMAGE_RES
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
|
||||
self.tv_transform = tv.Compose(
|
||||
[
|
||||
tv.Resize((self._size_height, self._size_width)),
|
||||
tv.ToTensor(),
|
||||
tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self, image: Image.Image) -> torch.Tensor:
|
||||
return self.tv_transform(image)
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches are coming from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, width in value:
|
||||
possible_resolutions.append((height * patch_size, width * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(
|
||||
max(new_size_without_distortion[1], 1),
|
||||
max(new_size_without_distortion[0], 1),
|
||||
),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas from the list of possible resolutions onto which the
|
||||
image can be resized without distortion.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the width and height of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (width, height).
|
||||
resize_to_max_canvas (bool): If True, will return the resolution that allows the largest upscaling.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The best resolution (width, height) for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> get_best_fit(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scale_h = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Two of the scales are > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest area:
|
||||
areas = tensor([150528, 100352])  # [224, 672], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
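The docstring example can be reproduced with plain tensor math; a small sketch of the same canvas selection (no class instance, resize_to_max_canvas=False assumed):

```
import torch

image_w, image_h = 200, 300
possible = torch.tensor([[224, 672], [672, 224], [224, 448], [448, 224], [224, 224]])
scale_w = possible[:, 0] / image_w
scale_h = possible[:, 1] / image_h
scales = torch.minimum(scale_w, scale_h)          # limiting side per candidate canvas
upscaling = scales[scales >= 1]
selected = upscaling.min() if len(upscaling) > 0 else scales[scales < 1].max()
candidates = possible[scales == selected]
areas = candidates[:, 0] * candidates[:, 1]       # prefer the smallest canvas (less padding)
print(candidates[areas.argmin()].tolist())        # [224, 448]
```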
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas that allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
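At the end of `__call__`, the tile grid and the returned aspect ratio fall directly out of the chosen canvas; a short sketch with assumed numbers (patch size 336, canvas 672x336):

```
size = 336                              # assumed patch size
best_resolution = (672, 336)            # assumed (width, height) canvas
ratio_w, ratio_h = best_resolution[0] // size, best_resolution[1] // size
print(ratio_w, ratio_h)                 # 2 1 -> the padded image is split into a 2x1 tile grid
print((ratio_h, ratio_w))               # (1, 2) -> aspect ratio returned alongside the tiles
```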
|
|
@@ -4,20 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from llama_stack.models.llama.prompt_format import (
|
||||
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from ..prompt_format import (
|
||||
Llama4UseCase,
|
||||
TextCompletionContent,
|
||||
UseCase,
|
||||
|
|
5
llama_stack/models/llama/llama4/quantization/__init__.py
Normal file
|
@@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
225
llama_stack/models/llama/llama4/quantization/loader.py
Normal file
|
@@ -0,0 +1,225 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...datatypes import QuantizationMode
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def swiglu_wrapper_no_reduce(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from ...quantize_impls import ffn_swiglu
|
||||
|
||||
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
|
||||
|
||||
def experts_batched_swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor, # (e, g, D)
|
||||
w1: Tensor, # (e, D, F)
|
||||
w3: Tensor, # (e, D, F)
|
||||
w2: Tensor, # (e, F, D)
|
||||
) -> torch.Tensor:
|
||||
from ...quantize_impls import bmm_nt
|
||||
|
||||
middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3) # noqa: N806
|
||||
return bmm_nt(middle_out_egF, w2)
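For reference, the batched expert SwiGLU above has the same shape signature as plain batched matmuls; a hedged plain-torch sketch (assuming `bmm_nt` reduces to an ordinary batched matmul for these already-laid-out weights; all dimensions are illustrative):

```
import torch
import torch.nn.functional as F

e, g, D, Fdim = 4, 8, 16, 32            # experts, tokens per expert, model dim, ffn dim
x = torch.randn(e, g, D)
w1, w3 = torch.randn(e, D, Fdim), torch.randn(e, D, Fdim)
w2 = torch.randn(e, Fdim, D)
middle = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)   # (e, g, F)
out = torch.bmm(middle, w2)                             # (e, g, D)
print(out.shape)
```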
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
use_rich_progress: bool = True,
|
||||
) -> Transformer:
|
||||
from ...quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
Int4ScaledWeights,
|
||||
load_fp8,
|
||||
load_int4,
|
||||
quantize_fp8,
|
||||
quantize_int4,
|
||||
)
|
||||
|
||||
rank = get_model_parallel_rank()
|
||||
|
||||
def should_quantize_block(block: nn.Module) -> bool:
|
||||
if not isinstance(block, TransformerBlock):
|
||||
return False
|
||||
|
||||
is_moe = isinstance(block.feed_forward, MoE)
|
||||
if quantization_mode == QuantizationMode.fp8_mixed:
|
||||
# skip quantization on first and last layers
|
||||
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
|
||||
|
||||
return is_moe
|
||||
|
||||
use_rich_progress = use_rich_progress and rank == 0
|
||||
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
|
||||
if os.path.isfile(int4_scales_path):
|
||||
log_status(f"Rank {rank}: Loading int4 scales")
|
||||
int4_scales = torch.load(int4_scales_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = int4_scales[key]
|
||||
return load_int4(
|
||||
weight,
|
||||
scale,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
|
||||
else:
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
log_status(f"Rank {rank}: Loading fp8 scales")
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = fp8_scales[key]
|
||||
return load_fp8(
|
||||
weight,
|
||||
scale,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
|
||||
processed_blocks = 0
|
||||
try:
|
||||
if use_rich_progress:
|
||||
progress.start()
|
||||
|
||||
for _, block in model.named_modules():
|
||||
if not should_quantize_block(block):
|
||||
continue
|
||||
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id}")
|
||||
|
||||
# Quantize only routed experts, not shared
|
||||
prefix = f"layers.{block.layer_id}.feed_forward"
|
||||
moe = block.feed_forward
|
||||
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
|
||||
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.experts, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
|
||||
setattr(
|
||||
moe.experts,
|
||||
key,
|
||||
apply_quantization(
|
||||
f"{prefix}.experts.{key}",
|
||||
param.transpose(1, 2).contiguous(),
|
||||
),
|
||||
)
|
||||
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
# Quantize shared experts
|
||||
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.shared_expert, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
|
||||
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
|
||||
|
||||
processed_blocks += 1
|
||||
update_status(message=None, completed=processed_blocks)
|
||||
|
||||
update_status(f"Rank {rank} - Moving parameters to CUDA")
|
||||
|
||||
param_count = 0
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
param_count += 1
|
||||
|
||||
update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
|
||||
finally:
|
||||
if use_rich_progress:
|
||||
progress.stop()
|
||||
|
||||
return model
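The conversion patches bound methods onto live module instances via `__get__` (see the `batched_swiglu` and shared-expert `forward` assignments above); a minimal sketch of that idiom with a toy class:

```
class Expert:
    def __init__(self, scale):
        self.scale = scale

def scaled_forward(self, x):
    return x * self.scale

expert = Expert(scale=3)
expert.forward = scaled_forward.__get__(expert)   # bind the function as an instance method
print(expert.forward(2))                          # 6
```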
|
||||
|
||||
|
||||
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
|
||||
def logging_callbacks(
|
||||
use_rich_progress: bool,
|
||||
rank: int,
|
||||
model: Transformer,
|
||||
should_quantize_block: Callable[[nn.Module], bool],
|
||||
):
|
||||
console = None
|
||||
if use_rich_progress:
|
||||
from rich.console import Console
|
||||
|
||||
console = Console(highlight=False)
|
||||
|
||||
def log_status(message: str) -> None:
|
||||
if use_rich_progress:
|
||||
console.print(message)
|
||||
elif rank == 0: # Only log from rank 0 for non-rich logging
|
||||
log.info(message)
|
||||
|
||||
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
|
||||
progress = None
|
||||
if use_rich_progress:
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
progress = Progress(
|
||||
SpinnerColumn(),
|
||||
BarColumn(complete_style="green", finished_style="bright_green"),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
TextColumn("ETA:"),
|
||||
TimeRemainingColumn(),
|
||||
TextColumn("[bold]{task.fields[status]}"),
|
||||
console=console,
|
||||
expand=True,
|
||||
)
|
||||
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
|
||||
|
||||
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
|
||||
if use_rich_progress:
|
||||
if message is not None:
|
||||
progress.update(task_id, status=message)
|
||||
if completed is not None:
|
||||
progress.update(task_id, completed=completed)
|
||||
elif rank == 0 and completed and completed % 10 == 0:
|
||||
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
|
||||
|
||||
return progress, log_status, update_status
|
|
@@ -4,9 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
@@ -59,8 +56,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
|
|||
"<|text_post_train_reserved_special_token_3|>",
|
||||
"<|text_post_train_reserved_special_token_4|>",
|
||||
"<|text_post_train_reserved_special_token_5|>",
|
||||
"<|python_start|>",
|
||||
"<|python_end|>",
|
||||
"<|finetune_right_pad|>",
|
||||
] + get_reserved_special_tokens(
|
||||
"text_post_train", 61, 6
|
||||
|
@@ -85,8 +80,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
|
|||
"vision", 1041, 7
|
||||
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
|
||||
|
||||
# 201134, ..., 201143
|
||||
LLAMA4_REASONING_SPECIAL_TOKENS = [
|
||||
"<|reasoning_reserved_special_token_0|>",
|
||||
"<|reasoning_reserved_special_token_1|>",
|
||||
"<|reasoning_reserved_special_token_2|>",
|
||||
"<|reasoning_reserved_special_token_3|>",
|
||||
"<|reasoning_reserved_special_token_4|>",
|
||||
"<|reasoning_reserved_special_token_5|>",
|
||||
"<|reasoning_reserved_special_token_6|>",
|
||||
"<|reasoning_reserved_special_token_7|>",
|
||||
"<|reasoning_thinking_start|>",
|
||||
"<|reasoning_thinking_end|>",
|
||||
]
|
||||
|
||||
LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
|
||||
LLAMA4_SPECIAL_TOKENS = (
|
||||
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
|
||||
)
|
||||
|
||||
BASIC_SPECIAL_TOKENS = [
|
||||
"<|begin_of_text|>",
|
||||
|
@@ -155,6 +165,9 @@ class Tokenizer:
|
|||
self.eot_id: int = self.special_tokens["<|eot|>"]
|
||||
self.eom_id: int = self.special_tokens["<|eom|>"]
|
||||
|
||||
self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
|
||||
self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
|
||||
|
||||
self.stop_tokens = [
|
||||
self.eos_id,
|
||||
self.special_tokens["<|eom|>"],
|
||||
|
|
209
llama_stack/models/llama/llama4/vision/embedding.py
Normal file
|
@@ -0,0 +1,209 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
|
||||
from ..args import VisionArgs
|
||||
from .encoder import VisionEncoder
|
||||
|
||||
|
||||
class PixelShuffle(nn.Module):
|
||||
def __init__(self, ps_ratio):
|
||||
super().__init__()
|
||||
self.ps_ratio = ps_ratio
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, N, C], N = number of patches
|
||||
assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
|
||||
assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
|
||||
hh = ww = int(math.sqrt(x.shape[1]))
|
||||
x = x.reshape(x.shape[0], hh, ww, -1)
|
||||
x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
|
||||
pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
return pixel_shuffle_patches
|
||||
|
||||
|
||||
def pixel_shuffle_op(input_x, ps_ratio):
|
||||
n, w, h, c = input_x.size()
|
||||
input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
input_x = input_x.view(
|
||||
n,
|
||||
int(h * ps_ratio),
|
||||
int(w * ps_ratio),
|
||||
int(c / (ps_ratio * ps_ratio)),
|
||||
)
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
return input_x
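A quick shape check of `pixel_shuffle_op` with an assumed ps_ratio of 0.5 (and assuming the function above is in scope): the patch grid shrinks by the ratio on each side while channels grow by 1/ratio^2. Sizes are illustrative only.

```
import torch

x = torch.randn(1, 24, 24, 1408)        # (n, w, h, c) encoded patch grid
y = pixel_shuffle_op(x, ps_ratio=0.5)
print(y.shape)                          # torch.Size([1, 12, 12, 5632])
```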
|
||||
|
||||
|
||||
class SimpleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
bias: bool = True,
|
||||
dropout: float = 0.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.non_linearity(self.c_proj(hidden))
|
||||
|
||||
|
||||
class PixelShuffleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ps_ratio: float,
|
||||
input_dim: int,
|
||||
output_dim: int = 4096,
|
||||
add_fc: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pixel_shuffle = PixelShuffle(ps_ratio)
|
||||
self.mlp = SimpleMLP(
|
||||
int(input_dim // (ps_ratio**2)),
|
||||
output_dim,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
self.fc = nn.Identity()
|
||||
if add_fc:
|
||||
self.fc = ColumnParallelLinear(
|
||||
output_dim,
|
||||
output_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
|
||||
encoded_patches = self.pixel_shuffle(encoded_patches)
|
||||
return self.fc(self.mlp(encoded_patches))
|
||||
|
||||
|
||||
class VisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, args: VisionArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
image_size = args.image_size
|
||||
patch_size = args.patch_size
|
||||
self.vision_encoder = VisionEncoder(
|
||||
image_size=(image_size.height, image_size.width),
|
||||
patch_size=(patch_size.height, patch_size.width),
|
||||
dim=args.dim,
|
||||
layers=args.n_layers,
|
||||
heads=args.n_heads,
|
||||
mlp_ratio=args.mlp_ratio,
|
||||
)
|
||||
self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
|
||||
self.vision_adapter = PixelShuffleMLP(
|
||||
ps_ratio=args.pixel_shuffle_ratio,
|
||||
input_dim=args.dim,
|
||||
output_dim=args.output_dim,
|
||||
)
|
||||
|
||||
self.output_dim = args.output_dim
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
original_sd = self.state_dict()
|
||||
for k in state_dict:
|
||||
if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
|
||||
state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
|
||||
|
||||
def _get_empty_sequence(self, h):
|
||||
return torch.zeros(
|
||||
h.shape[0],
|
||||
h.shape[1],
|
||||
self.output_dim,
|
||||
device=h.device,
|
||||
dtype=h.dtype,
|
||||
)
|
||||
|
||||
# image_batch is batched; each batch sample contains a list of images, so this is List[List[torch.Tensor]]
|
||||
# each image is a tensor of shape [num_tiles, C, H, W]
|
||||
def forward(
|
||||
self,
|
||||
image_batch: List[List[torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
h_ref: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
images_flattened = [image for sample in image_batch for image in sample]
|
||||
images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
|
||||
embedding = self.vision_encoder(images_flattened)
|
||||
projected_embedding = self.vision_adapter(embedding)
|
||||
|
||||
h_image = self._get_empty_sequence(h_ref)
|
||||
return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
|
||||
|
||||
|
||||
def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
|
||||
# If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
|
||||
# `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
|
||||
# `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
|
||||
num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
|
||||
|
||||
assert not torch.isnan(encoded_patches_proj).any()
|
||||
assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
|
||||
f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
|
||||
)
|
||||
|
||||
encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
|
||||
for index in range(h_image.size(0)):
|
||||
encoded_patches_per_sample = encoded_patches_list[index]
|
||||
sample_image_mask = image_mask[index]
|
||||
|
||||
if encoded_patches_per_sample.numel() == 0:
|
||||
continue
|
||||
encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
|
||||
-1, encoded_patches_per_sample.size(-1)
|
||||
)
|
||||
|
||||
n_tokens_to_fill = sample_image_mask.sum()
|
||||
assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
|
||||
|
||||
h_image[index].masked_scatter_(
|
||||
sample_image_mask.expand(-1, h_image.size(-1)),
|
||||
encoded_patches_per_sample[:n_tokens_to_fill],
|
||||
)
|
||||
|
||||
return h_image
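The `masked_scatter_` call above overwrites image-token positions with projected patch embeddings in mask order; a toy standalone illustration:

```
import torch

h = torch.zeros(6, 4)                                           # (seq_len, dim) for one sample
mask = torch.tensor([0, 1, 1, 0, 1, 0]).bool().unsqueeze(-1)    # three image-token slots
patches = torch.arange(3 * 4, dtype=torch.float32).reshape(3, 4)
h.masked_scatter_(mask.expand(-1, h.size(-1)), patches)
print(h[1], h[2], h[4])                                         # filled with patches[0], [1], [2]
```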
|
411
llama_stack/models/llama/llama4/vision/encoder.py
Normal file
|
@@ -0,0 +1,411 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from torch import einsum
|
||||
|
||||
from ..args import ModelArgs
|
||||
from ..model import Attention
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return x
|
||||
|
||||
|
||||
class ColumnParallelConv2dPatch(torch.nn.Module):
|
||||
"""Conv2D Patching layer with model parallelism.
|
||||
Column parallel over unfolded input.
|
||||
Arguments:
|
||||
in_channels: Input channels.
|
||||
out_channels: Output channels.
|
||||
kernel_size: Size of convolution kernel.
|
||||
stride (default 1): Stride for convolution.
|
||||
bias (default False): Use bias in Conv2d.
|
||||
Input: (bsz, in_channels, height, width)
|
||||
Output: (bsz, num_tokens, out_channels)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]],
|
||||
bias: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
|
||||
self._linear = ColumnParallelLinear(
|
||||
in_channels * kernel_size[0] * kernel_size[1],
|
||||
out_channels,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self._unfold(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self._linear(x)
|
||||
return x
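The unfold + column-parallel linear above is the patch-embedding equivalent of a strided Conv2d; a sketch with a plain `nn.Linear` standing in for `ColumnParallelLinear` (which needs an initialized model-parallel group), using assumed sizes:

```
import torch
import torch.nn as nn

patch = 14
x = torch.randn(2, 3, 448, 448)                        # (bsz, C, H, W)
unfold = nn.Unfold(kernel_size=patch, stride=patch)
linear = nn.Linear(3 * patch * patch, 1408, bias=False)
tokens = linear(unfold(x).permute(0, 2, 1))            # (bsz, num_patches, dim)
print(tokens.shape)                                    # torch.Size([2, 1024, 1408])
```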
|
||||
|
||||
|
||||
class _FeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
dropout: float,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.c_proj(hidden)
|
||||
|
||||
|
||||
class _TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert d_model % n_head == 0
|
||||
self.n_heads = n_head
|
||||
self.head_dim = d_model // self.n_heads
|
||||
|
||||
attn_args = ModelArgs(
|
||||
dim=d_model,
|
||||
head_dim=self.head_dim,
|
||||
n_heads=self.n_heads,
|
||||
n_kv_heads=self.n_heads,
|
||||
)
|
||||
self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = _FeedForward(
|
||||
dim=d_model,
|
||||
hidden_dim=int(mlp_ratio * d_model),
|
||||
dropout=0.0,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.gated = gated
|
||||
if gated:
|
||||
self.gate_attn = nn.Parameter(torch.zeros(1))
|
||||
self.gate_ffn = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def attention(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
||||
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
||||
|
||||
x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
|
||||
x = x + _gate_ffn * self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class _Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
_TransformerBlock(
|
||||
d_model=dim,
|
||||
n_head=heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
gated=gated,
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
|
||||
out = []
|
||||
for idx, r in enumerate(self.resblocks):
|
||||
if return_intermediate is not None and idx in return_intermediate:
|
||||
out.append(x)
|
||||
x = r(x, mask=mask, freq_cis=freq_cis)
|
||||
if return_intermediate is not None:
|
||||
return x, torch.stack(out, dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
class PackingIndex:
|
||||
Z = 0 # Z (time) coordinate of the token in the original sample
|
||||
Y = 1 # Y (height) coordinate of the token in the original sample
|
||||
X = 2 # X (width) coordinate of the token in the original sample
|
||||
TIME = 3 # Total number of time units (frames) in the original sample
|
||||
HEIGHT = 4 # Height of the original sample
|
||||
WIDTH = 5 # Width of the original sample
|
||||
# USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
|
||||
IDX = 6 # Full index of the token in the original sample (x + y * w + z * w * h)
|
||||
BATCH_IDX = 7 # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
|
||||
|
||||
# Total size of the enum, remember to update this!
|
||||
NUM_METADATA = 8
|
||||
|
||||
# Note: For padding tokens IDX = -1
|
||||
# For cls tokens, IDX = -2
|
||||
ID_CLS_TOKEN = -2
|
||||
ID_PAD_TOKEN = -1
|
||||
|
||||
|
||||
class VisionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
patch_size: Tuple[int, int],
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float,
|
||||
in_channels: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (
|
||||
self.image_size[0] // self.patch_size[0],
|
||||
self.image_size[1] // self.patch_size[1],
|
||||
)
|
||||
self.conv1 = ColumnParallelConv2dPatch(
|
||||
in_channels=in_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
)
|
||||
scale = dim**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(dim))
|
||||
|
||||
self.positional_embedding_vlm = nn.Parameter(
|
||||
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
|
||||
)
|
||||
|
||||
self.ln_pre = LayerNorm(dim)
|
||||
self.ln_post = LayerNorm(dim)
|
||||
self.transformer = _Transformer(
|
||||
dim,
|
||||
layers,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
|
||||
# NOTE: hack for the fixed res
|
||||
image_h, image_w = self.image_size
|
||||
patch_h, patch_w = self.patch_size
|
||||
idx_h, idx_w = image_h // patch_h, image_w // patch_w
|
||||
img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
|
||||
img_idx = img_idx.reshape(idx_h * idx_w, 1)
|
||||
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
||||
img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
|
||||
|
||||
packed_img_idx = torch.empty(
|
||||
img_idx.shape[0],
|
||||
img_idx.shape[1],
|
||||
PackingIndex.NUM_METADATA - 1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
|
||||
packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
|
||||
packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
|
||||
packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
|
||||
packed_img_idx[:, :, PackingIndex.IDX] = img_idx
|
||||
packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
|
||||
self.packed_img_idx = packed_img_idx # for positional embedding load hook
|
||||
|
||||
# compute rope freqs
|
||||
rope_freq = self.get_rope_freqs(dim // heads // 2)
|
||||
freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
|
||||
freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
|
||||
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
||||
# disable RoPE for padding and cls tokens
|
||||
freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
|
||||
# compute complex freqs
|
||||
self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
|
||||
# xlf automatically broadcasts
|
||||
self.freq_cis = self.freq_cis.squeeze(0)
|
||||
self.n_heads = heads // fs_init.get_model_parallel_world_size()
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def get_rope_freqs(self, dim, theta=10000):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
return freqs
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
|
||||
def compute_rope_freqs(self, freqs, t):
|
||||
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
||||
freqs = freqs.repeat_interleave(2, dim=-1)
|
||||
return freqs
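For intuition, the base RoPE frequencies produced by `get_rope_freqs` follow the usual inverse-power schedule; a tiny illustration with an assumed dim of 8:

```
import torch

dim, theta = 8, 10000
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
print(freqs)   # tensor([1.0000, 0.1000, 0.0100, 0.0010])
```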
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
||||
if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
|
||||
raise ValueError(
|
||||
f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
|
||||
)
|
||||
|
||||
batch_size, token_per_image, _ = self.packed_img_idx.shape
|
||||
# Input points for idx are [x, y, w, h]
|
||||
idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
|
||||
total_windows, window_size, _ = idx.shape
|
||||
|
||||
# Grid values are [-1, 1] and coords are w, h
|
||||
grid = (
|
||||
(idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
|
||||
)[None, ...]
|
||||
|
||||
# In this mode, cls token has no position embedding
|
||||
if orig_pos_embed is not None:
|
||||
posemb = (
|
||||
orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
posemb = posemb.to(device=grid.device, dtype=grid.dtype)
|
||||
sample = F.grid_sample(
|
||||
posemb, grid, padding_mode="zeros"
|
||||
) # padding tokens / class token will get zero for posemb
|
||||
sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
|
||||
sample = torch.where(
|
||||
idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
|
||||
orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
|
||||
sample,
|
||||
)
|
||||
|
||||
new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
|
||||
|
||||
state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
|
||||
|
||||
if return_state_dict:
|
||||
return state_dict
|
||||
|
||||
def apply_class_embedding(self, x):
|
||||
x = torch.cat(
|
||||
[
|
||||
x,
|
||||
self.class_embedding.to(x.dtype)
|
||||
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||
],
|
||||
dim=1,
|
||||
) # shape = [*, grid ** 2 + 1, width]
|
||||
return x
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
|
||||
if images.ndim == 5:
|
||||
num_concurrent_media = 1
|
||||
bsz, num_chunks, nch, h, w = images.shape
|
||||
else:
|
||||
bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
|
||||
|
||||
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
# patch embedding
|
||||
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
x = self.conv1(x)  # shape = [*, num_patches, dim]
|
||||
_, ntok, dim = x.shape
|
||||
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
|
||||
|
||||
# apply cls token
|
||||
x = self.apply_class_embedding(x)
|
||||
ntok += 1
|
||||
|
||||
# apply position embeddings
|
||||
if self.positional_embedding_vlm is not None:
|
||||
x = x + self.positional_embedding_vlm.to(x.dtype)
|
||||
|
||||
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
|
||||
|
||||
x = self.ln_pre(x)
|
||||
x = x.view(bsz * num_concurrent_media, -1, dim)
|
||||
freq_cis = self.freq_cis.to(images.device)
|
||||
|
||||
tf_output = self.transformer(
|
||||
x,
|
||||
freq_cis=freq_cis,
|
||||
)
|
||||
|
||||
int_x = None
|
||||
if isinstance(tf_output, tuple):
|
||||
x, int_x = tf_output
|
||||
else:
|
||||
x = tf_output
|
||||
x = self.ln_post(x)
|
||||
|
||||
# remove cls token output
|
||||
x = x[:, :-1, :]
|
||||
|
||||
# add and output x + int_x features
|
||||
if int_x is not None:
|
||||
int_x = int_x[:, :-1, :, :]
|
||||
int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
|
||||
x = torch.cat([x, int_x], dim=-1)
|
||||
|
||||
return x
|