# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference-specific tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --env INFERENCE_CHECKPOINT_DIR=<DIR> \
  --env MODEL_PARALLEL_SIZE=4 \
  --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to it using:

```
pytest -s -v tests/integration/inference/test_text_inference.py \
  --stack-config http://localhost:8321 \
  --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import math
from typing import Any, Callable, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear

from ..args import VisionArgs
from .encoder import VisionEncoder
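

# Overview (descriptive note): VisionEncoder produces per-patch features for each image
# tile; PixelShuffle / PixelShuffleMLP downsample those patches and project them to
# args.output_dim; VisionEmbeddings then scatters the projected patches into the token
# sequence at the positions marked by image_mask.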
class PixelShuffle(nn.Module):
    def __init__(self, ps_ratio):
        super().__init__()
        self.ps_ratio = ps_ratio

    def forward(self, x):
        # x: [B, N, C], N = number of patches
        assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
        assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
        hh = ww = int(math.sqrt(x.shape[1]))
        x = x.reshape(x.shape[0], hh, ww, -1)
        x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
        pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
        return pixel_shuffle_patches
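

# Shape sketch (illustrative): with ps_ratio = r, pixel_shuffle_op maps a patch grid
# [B, H, W, C] -> [B, H*r, W*r, C / r**2]. For e.g. r = 0.5 this halves each spatial
# side and quadruples the channel dim, which is why PixelShuffleMLP sizes its MLP input
# as input_dim // (ps_ratio**2).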
def pixel_shuffle_op(input_x, ps_ratio):
    n, w, h, c = input_x.size()
    input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
    input_x = input_x.permute(0, 2, 1, 3).contiguous()
    input_x = input_x.view(
        n,
        int(h * ps_ratio),
        int(w * ps_ratio),
        int(c / (ps_ratio * ps_ratio)),
    )
    input_x = input_x.permute(0, 2, 1, 3).contiguous()
    return input_x
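

# Parallelism note: in SimpleMLP, c_fc is a ColumnParallelLinear with gather_output=False,
# so its output stays sharded across model-parallel ranks; c_proj is a RowParallelLinear
# with input_is_parallel=True, which consumes that shard and all-reduces the partial
# results. This is the usual fairscale tensor-parallel MLP pairing.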
class SimpleMLP(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        bias: bool = True,
        dropout: float = 0.0,
        act_layer: Callable = nn.GELU,
    ):
        super().__init__()
        # layers
        self.c_fc = ColumnParallelLinear(
            dim,
            hidden_dim,
            bias=bias,
            gather_output=False,
        )
        self.c_proj = RowParallelLinear(
            hidden_dim,
            hidden_dim,
            bias=bias,
            input_is_parallel=True,
        )
        self.non_linearity = act_layer()
        self.dropout = dropout

    def forward(self, x):
        hidden = self.c_fc(x)
        hidden = self.non_linearity(hidden)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.non_linearity(self.c_proj(hidden))
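

# PixelShuffleMLP (descriptive note): pixel-shuffling a [B, N, C] patch sequence with
# ratio r yields [B, N * r**2, C / r**2] (fewer tokens, wider channels); SimpleMLP then
# projects those wider patches to output_dim.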
class PixelShuffleMLP(torch.nn.Module):
    def __init__(
        self,
        ps_ratio: float,
        input_dim: int,
        output_dim: int = 4096,
        add_fc: bool = False,
    ):
        super().__init__()
        self.pixel_shuffle = PixelShuffle(ps_ratio)
        self.mlp = SimpleMLP(
            int(input_dim // (ps_ratio**2)),
            output_dim,
            bias=False,
            dropout=0.0,
            act_layer=nn.GELU,
        )
        self.fc = nn.Identity()
        if add_fc:
            self.fc = ColumnParallelLinear(
                output_dim,
                output_dim,
                bias=False,
            )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        encoded_patches = self.pixel_shuffle(encoded_patches)
        return self.fc(self.mlp(encoded_patches))
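

# VisionEmbeddings (descriptive note): flattens the per-sample image lists into one batch
# of tiles, runs the (bfloat16) vision encoder and the pixel-shuffle adapter, then scatters
# the projected patch embeddings into a zero-initialized [B, seq_len, output_dim] sequence
# at the positions selected by image_mask.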
class VisionEmbeddings(torch.nn.Module):
    def __init__(self, args: VisionArgs):
        super().__init__()
        self.args = args

        image_size = args.image_size
        patch_size = args.patch_size
        self.vision_encoder = VisionEncoder(
            image_size=(image_size.height, image_size.width),
            patch_size=(patch_size.height, patch_size.width),
            dim=args.dim,
            layers=args.n_layers,
            heads=args.n_heads,
            mlp_ratio=args.mlp_ratio,
        )
        self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
        self.vision_adapter = PixelShuffleMLP(
            ps_ratio=args.pixel_shuffle_ratio,
            input_dim=args.dim,
            output_dim=args.output_dim,
        )

        self.output_dim = args.output_dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool = True,
        missing_keys: List[str] = None,
        unexpected_keys: List[str] = None,
        error_msgs: List[str] = None,
        return_state_dict: bool = False,
    ) -> None:
        original_sd = self.state_dict()
        for k in state_dict:
            if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
                state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)

    def _get_empty_sequence(self, h):
        return torch.zeros(
            h.shape[0],
            h.shape[1],
            self.output_dim,
            device=h.device,
            dtype=h.dtype,
        )

    # image_batch is batched; each batch sample contains a list of images, so this is List[List[torch.Tensor]]
    # each image is a tensor of shape [num_tiles, C, H, W]
    def forward(
        self,
        image_batch: List[List[torch.Tensor]],
        image_mask: torch.Tensor,
        h_ref: torch.Tensor,
    ) -> torch.Tensor:
        images_flattened = [image for sample in image_batch for image in sample]
        images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
        embedding = self.vision_encoder(images_flattened)
        projected_embedding = self.vision_adapter(embedding)

        h_image = self._get_empty_sequence(h_ref)
        return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
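

# Note: image_mask is expected to be [B, seq_len, 1]; each sample's mask is expanded to the
# embedding width before masked_scatter_, and only the first n_tokens_to_fill projected
# patches per sample are written (extra patches, if any, are dropped).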
def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
    # If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
    # `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
    # `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
    num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]

    assert not torch.isnan(encoded_patches_proj).any()
    assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
        f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
    )

    encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
    for index in range(h_image.size(0)):
        encoded_patches_per_sample = encoded_patches_list[index]
        sample_image_mask = image_mask[index]

        if encoded_patches_per_sample.numel() == 0:
            continue
        encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
            -1, encoded_patches_per_sample.size(-1)
        )

        n_tokens_to_fill = sample_image_mask.sum()
        assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)

        h_image[index].masked_scatter_(
            sample_image_mask.expand(-1, h_image.size(-1)),
            encoded_patches_per_sample[:n_tokens_to_fill],
        )

    return h_image