forked from phoenix-oss/llama-stack-mirror
		
	# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
		
			
				
	
	
		
			1429 lines
		
	
	
	
		
			50 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1429 lines
		
	
	
	
		
			50 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import logging
 | |
| import math
 | |
| from collections.abc import Callable
 | |
| from functools import partial
 | |
| from typing import Any
 | |
| 
 | |
| import fairscale.nn.model_parallel.initialize as fs_init
 | |
| import torch
 | |
| import torch.nn.functional as F
 | |
| from fairscale.nn.model_parallel.layers import (
 | |
|     ColumnParallelLinear,
 | |
|     RowParallelLinear,
 | |
|     VocabParallelEmbedding,
 | |
| )
 | |
| from PIL import Image as PIL_Image
 | |
| from torch import Tensor, nn
 | |
| from torch.distributed import _functional_collectives as funcol
 | |
| 
 | |
| from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
 | |
| from .encoder_utils import (
 | |
|     build_encoder_attention_mask,
 | |
|     contract_num_tokens_from_mult8,
 | |
|     expand_num_tokens_to_mult8,
 | |
|     initialize_global_position_embedding_from_local,
 | |
|     resize_global_position_embedding,
 | |
|     resize_local_position_embedding,
 | |
| )
 | |
| from .image_transform import VariableSizeImageTransform
 | |
| from .utils import get_negative_inf_value, to_2tuple
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| MP_SCALE = 8
 | |
| 
 | |
| 
 | |
| def reduce_from_tensor_model_parallel_region(input_):
 | |
|     """All-reduce the input tensor across model parallel group."""
 | |
|     output = funcol.all_reduce(input_, "sum", group=fs_init.get_model_parallel_group())
 | |
|     output = funcol.wait_tensor(output)
 | |
|     return output
 | |
| 
 | |
| 
 | |
| def gather_from_tensor_model_parallel_region(input_):
 | |
|     """Gather tensors and concatenate along the last dimension."""
 | |
| 
 | |
|     world_size = fs_init.get_model_parallel_world_size()
 | |
|     # Size and dimension.
 | |
|     last_dim = input_.dim() - 1
 | |
|     rank = fs_init.get_model_parallel_rank()
 | |
| 
 | |
|     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
 | |
|     tensor_list[rank] = input_
 | |
|     output = funcol.all_gather_tensor(
 | |
|         input_,
 | |
|         gather_dim=last_dim,
 | |
|         group=fs_init.get_model_parallel_group(),
 | |
|     )
 | |
|     output = funcol.wait_tensor(output)
 | |
|     return output
 | |
| 
 | |
| 
 | |
| def _get_full_row_masked_out_mask(
 | |
|     attn_bias,
 | |
|     negative_inf_value,
 | |
| ):
 | |
|     """
 | |
|     attn_bias should be a 4D tensor of shape [B, H, S1, S2]
 | |
|     where B is the batch size, H is the number of heads,
 | |
|     and S1/S2 are the sequence lengths. This returns
 | |
|     a 4D tensor of shape [B, H, S1, 1] which stores boolean
 | |
|     values which are 0 if the a full row in the last dimension
 | |
|     contains negative infinity values, otherwise it's 1.
 | |
|     """
 | |
|     return (attn_bias != negative_inf_value).any(dim=-1).type_as(attn_bias)[..., None]
 | |
| 
 | |
| 
 | |
| # Image encoder for inference
 | |
| class LayerNorm(nn.LayerNorm):
 | |
|     """Subclass torch's LayerNorm to handle fp16."""
 | |
| 
 | |
|     def forward(self, x: torch.Tensor):
 | |
|         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class ColumnParallelConv2dPatch(torch.nn.Module):
 | |
|     """Conv2D Patching layer with model parallelism.
 | |
|     Column parallel over unfolded input.
 | |
|     Arguments:
 | |
|         in_channels: Input channels.
 | |
|         out_channels: Output channels.
 | |
|         kernel_size: Size of convolution kernel.
 | |
|         stride (default 1): Stride for convolution.
 | |
|         bias (default False): Use bias in Conv2d.
 | |
|     Input: (bsz, in_channels, width, height)
 | |
|     Output: (bsz, num_tokens, out_channels)
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         in_channels: int,
 | |
|         out_channels: int,
 | |
|         kernel_size: int | tuple[int, int],
 | |
|         stride: int | tuple[int, int],
 | |
|         bias: bool | None = False,
 | |
|     ) -> None:
 | |
|         super().__init__()
 | |
|         if isinstance(kernel_size, int):
 | |
|             kernel_size = (kernel_size, kernel_size)
 | |
|         self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
 | |
|         self._linear = ColumnParallelLinear(
 | |
|             in_channels * kernel_size[0] * kernel_size[1],
 | |
|             out_channels,
 | |
|             bias=bias,
 | |
|         )
 | |
| 
 | |
|     def forward(self, x: torch.Tensor) -> torch.Tensor:
 | |
|         x = self._unfold(x)
 | |
|         x = x.permute(0, 2, 1)
 | |
|         x = F.linear(x, self._linear.weight)
 | |
|         x = gather_from_tensor_model_parallel_region(x)
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class ImageFeedForward(torch.nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         dim: int,
 | |
|         hidden_dim: int,
 | |
|         dropout: float,
 | |
|         act_layer: Callable = nn.GELU,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         # layers
 | |
|         self.c_fc = ColumnParallelLinear(
 | |
|             dim,
 | |
|             hidden_dim,
 | |
|             bias=True,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.c_proj = RowParallelLinear(
 | |
|             hidden_dim,
 | |
|             dim,
 | |
|             bias=True,
 | |
|             input_is_parallel=True,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.non_linearity = act_layer()
 | |
|         self.dropout = dropout
 | |
| 
 | |
|     def forward(self, x):
 | |
|         hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias)
 | |
|         hidden = self.non_linearity(hidden)
 | |
|         hidden = F.linear(hidden, self.c_proj.weight)
 | |
|         hidden = reduce_from_tensor_model_parallel_region(hidden)
 | |
|         hidden += self.c_proj.bias
 | |
|         return hidden
 | |
| 
 | |
| 
 | |
| class ImageAttention(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         dim,
 | |
|         head_dim,
 | |
|         n_heads,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         world_size = fs_init.get_model_parallel_world_size()
 | |
|         qkvo_replication = 1
 | |
|         if world_size > 16:
 | |
|             qkvo_replication = world_size // 8
 | |
| 
 | |
|         self.n_kv_heads = n_heads
 | |
|         self.n_local_heads = n_heads * qkvo_replication // world_size
 | |
|         self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
 | |
|         self.n_rep = self.n_local_heads // self.n_local_kv_heads
 | |
|         self.head_dim = dim // n_heads
 | |
| 
 | |
|         self.wq = ColumnParallelLinear(
 | |
|             dim,
 | |
|             qkvo_replication * n_heads * self.head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.wk = ColumnParallelLinear(
 | |
|             dim,
 | |
|             qkvo_replication * self.n_kv_heads * self.head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.wv = ColumnParallelLinear(
 | |
|             dim,
 | |
|             qkvo_replication * self.n_kv_heads * self.head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.wo = RowParallelLinear(
 | |
|             qkvo_replication * n_heads * self.head_dim,
 | |
|             dim,
 | |
|             bias=False,
 | |
|             input_is_parallel=True,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.qkvo_replication = qkvo_replication
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         mask: torch.Tensor = None,
 | |
|     ):
 | |
|         xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
 | |
| 
 | |
|         bs, slen, _ = xq.shape
 | |
| 
 | |
|         xq = xq.view(bs, slen, self.n_local_heads, self.head_dim)
 | |
|         xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim)
 | |
|         xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim)
 | |
| 
 | |
|         xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)]
 | |
| 
 | |
|         xk = xk.repeat_interleave(self.n_rep, dim=1)
 | |
|         xv = xv.repeat_interleave(self.n_rep, dim=1)
 | |
| 
 | |
|         attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
 | |
| 
 | |
|         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
 | |
| 
 | |
|         out = F.linear(attn_output, self.wo.weight)
 | |
|         out = reduce_from_tensor_model_parallel_region(out)
 | |
|         out = out / self.qkvo_replication
 | |
|         return out
 | |
| 
 | |
| 
 | |
| class ImageTransformerBlock(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         d_model: int,
 | |
|         n_head: int,
 | |
|         mlp_ratio: float = 4.0,
 | |
|         act_layer: Callable = nn.GELU,
 | |
|         gated: bool = False,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         assert d_model % n_head == 0
 | |
|         self.n_heads = n_head
 | |
|         self.head_dim = d_model // self.n_heads
 | |
|         self.attn = ImageAttention(
 | |
|             dim=d_model,
 | |
|             head_dim=self.head_dim,
 | |
|             n_heads=self.n_heads,
 | |
|         )
 | |
|         self.ln_1 = LayerNorm(d_model)
 | |
|         self.mlp = ImageFeedForward(
 | |
|             dim=d_model,
 | |
|             hidden_dim=int(mlp_ratio * d_model),
 | |
|             dropout=0.0,
 | |
|             act_layer=act_layer,
 | |
|         )
 | |
|         self.ln_2 = LayerNorm(d_model)
 | |
|         self.gated = gated
 | |
|         if gated:
 | |
|             self.gate_attn = nn.Parameter(torch.zeros(1))
 | |
|             self.gate_ffn = nn.Parameter(torch.zeros(1))
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         mask: torch.Tensor = None,
 | |
|     ):
 | |
|         _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
 | |
|         _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
 | |
|         x = x + _gate_attn * self.attn(self.ln_1(x), mask=mask)
 | |
|         x = x + _gate_ffn * self.mlp(self.ln_2(x))
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class ImageTransformer(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         width: int,
 | |
|         layers: int,
 | |
|         heads: int,
 | |
|         mlp_ratio: float = 4.0,
 | |
|         act_layer: Callable = nn.GELU,
 | |
|         gated: bool = False,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.width = width
 | |
|         self.layers = layers
 | |
|         self.resblocks = nn.ModuleList(
 | |
|             [
 | |
|                 ImageTransformerBlock(
 | |
|                     d_model=width,
 | |
|                     n_head=heads,
 | |
|                     mlp_ratio=mlp_ratio,
 | |
|                     act_layer=act_layer,
 | |
|                     gated=gated,
 | |
|                 )
 | |
|                 for _ in range(self.layers)
 | |
|             ]
 | |
|         )
 | |
| 
 | |
|     def forward(self, x: torch.Tensor, return_intermediate=None, mask=None):
 | |
|         out = []
 | |
|         for idx, r in enumerate(self.resblocks):
 | |
|             if return_intermediate is not None and idx in return_intermediate:
 | |
|                 out.append(x)
 | |
|             x = r(x, mask=mask)
 | |
|         if return_intermediate is not None:
 | |
|             return x, torch.stack(out, dim=-1)
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class VisionEncoder(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         max_num_tiles: int,
 | |
|         ckpt_path: str = None,
 | |
|         image_size: int = 224,
 | |
|         patch_size: int = 14,
 | |
|         width: int = 1280,
 | |
|         layers: int = 32,
 | |
|         heads: int = 16,
 | |
|         mlp_ratio: float = 4.0,
 | |
|         act_layer: Callable = nn.GELU,
 | |
|         in_channels: int = 3,
 | |
|         load_ckpt: bool = False,
 | |
|         n_global_layers: int = 2,
 | |
|         global_model: bool = False,
 | |
|         return_intermediate=None,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.global_model = global_model
 | |
|         self.return_intermediate = return_intermediate
 | |
|         self.max_num_tiles = max_num_tiles
 | |
|         self.image_size = to_2tuple(image_size)
 | |
|         self.patch_size = to_2tuple(patch_size)
 | |
|         self.grid_size = (
 | |
|             self.image_size[0] // self.patch_size[0],
 | |
|             self.image_size[1] // self.patch_size[1],
 | |
|         )
 | |
|         self.conv1 = ColumnParallelConv2dPatch(
 | |
|             in_channels=in_channels,
 | |
|             out_channels=width,
 | |
|             kernel_size=patch_size,
 | |
|             stride=patch_size,
 | |
|             bias=False,
 | |
|         )
 | |
|         scale = width**-0.5
 | |
|         self.class_embedding = nn.Parameter(scale * torch.randn(width))
 | |
|         self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
 | |
|         self.ln_post = LayerNorm(width)
 | |
|         self.ln_pre = LayerNorm(width)
 | |
|         self.transformer = ImageTransformer(width, layers, heads, mlp_ratio, act_layer=act_layer)
 | |
|         # pre and post tile position embedding
 | |
|         self.global_transformer = ImageTransformer(
 | |
|             width, n_global_layers, heads, mlp_ratio, act_layer=act_layer, gated=True
 | |
|         )
 | |
|         # pre and post tile position embedding
 | |
|         self.pre_tile_pos_embed = TilePositionEmbedding(
 | |
|             num_tiles=max_num_tiles,
 | |
|             width=width,
 | |
|             gated=True,
 | |
|         )
 | |
|         self.post_tile_pos_embed = TilePositionEmbedding(
 | |
|             num_tiles=max_num_tiles,
 | |
|             width=width,
 | |
|             gated=True,
 | |
|         )
 | |
|         self.gated_positional_embedding = nn.Parameter(
 | |
|             scale
 | |
|             * torch.randn(
 | |
|                 max_num_tiles,
 | |
|                 max_num_tiles,
 | |
|                 self.grid_size[0] * self.grid_size[1] + 1,
 | |
|                 width,
 | |
|             )
 | |
|         )
 | |
|         self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1))
 | |
| 
 | |
|         self._register_load_state_dict_pre_hook(self.load_hook)
 | |
| 
 | |
|     def load_hook(
 | |
|         self,
 | |
|         state_dict: dict[str, Any],
 | |
|         prefix: str,
 | |
|         local_metadata: dict[str, Any],
 | |
|         strict: bool = True,
 | |
|         missing_keys: list[str] = None,
 | |
|         unexpected_keys: list[str] = None,
 | |
|         error_msgs: list[str] = None,
 | |
|         return_state_dict: bool = False,
 | |
|     ) -> None:
 | |
|         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
 | |
|         if orig_pos_embed is not None:
 | |
|             new_pos_embed = resize_local_position_embedding(orig_pos_embed, self.grid_size)
 | |
|             state_dict[prefix + "positional_embedding"] = new_pos_embed
 | |
|         if hasattr(self, "gated_positional_embedding"):
 | |
|             if prefix + "gated_positional_embedding" not in state_dict:
 | |
|                 # resize positional_embedding to fit the new grid size
 | |
|                 global_pos_embed = initialize_global_position_embedding_from_local(
 | |
|                     new_pos_embed,
 | |
|                     self.grid_size,
 | |
|                     self.max_num_tiles,
 | |
|                     self.max_num_tiles,
 | |
|                 )
 | |
|                 state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
 | |
|                 state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype)
 | |
|                 logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
 | |
|             else:
 | |
|                 global_pos_embed = resize_global_position_embedding(
 | |
|                     state_dict[prefix + "gated_positional_embedding"],
 | |
|                     self.grid_size,
 | |
|                     self.max_num_tiles,
 | |
|                     self.max_num_tiles,
 | |
|                 )
 | |
|                 logger.info(
 | |
|                     f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}"
 | |
|                 )
 | |
|                 state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
 | |
|         if return_state_dict:
 | |
|             return state_dict
 | |
| 
 | |
|     def apply_positional_embedding(self, x, ar):
 | |
|         # apply regular position embedding
 | |
|         bsz, num_chunks, num_tokens, dim = x.shape
 | |
|         x = x.view(bsz * num_chunks, num_tokens, dim)
 | |
|         x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh())
 | |
|         x = x.view(bsz, num_chunks, num_tokens, dim)
 | |
|         for idx, arx in enumerate(ar):
 | |
|             _pos_embed = self.gated_positional_embedding[: arx[0], : arx[1]]
 | |
|             _pos_embed = _pos_embed.reshape(arx[0] * arx[1], *_pos_embed.shape[2:])
 | |
|             x[idx, : arx[0] * arx[1]] += _pos_embed * self.gated_positional_embedding_gate.tanh()
 | |
|         return x
 | |
| 
 | |
|     def apply_class_embedding(self, x):
 | |
|         x = torch.cat(
 | |
|             [
 | |
|                 self.class_embedding.to(x.dtype)
 | |
|                 + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
 | |
|                 x,
 | |
|             ],
 | |
|             dim=1,
 | |
|         )  # shape = [*, grid ** 2 + 1, width]
 | |
|         return x
 | |
| 
 | |
|     def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor:
 | |
|         if images.ndim == 5:
 | |
|             num_concurrent_media = 1
 | |
|             bsz, num_chunks, nch, w, h = images.shape
 | |
|         else:
 | |
|             bsz, num_concurrent_media, num_chunks, nch, w, h = images.shape
 | |
| 
 | |
|         images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h)
 | |
|         ar = ar.reshape(bsz * num_concurrent_media, 2)
 | |
| 
 | |
|         # patch embedding
 | |
|         x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h)
 | |
|         x = self.conv1(x)  # shape = [*, width, grid ** 2]
 | |
|         _, ntok, dim = x.shape
 | |
|         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
 | |
| 
 | |
|         # tile embeddings
 | |
|         x = self.pre_tile_pos_embed(x, ar)
 | |
|         x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
 | |
| 
 | |
|         # apply cls token
 | |
|         x = self.apply_class_embedding(x)
 | |
|         ntok += 1
 | |
| 
 | |
|         # apply position embeddings
 | |
|         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
 | |
|         x = self.apply_positional_embedding(x, ar)
 | |
| 
 | |
|         x = self.ln_pre(x)
 | |
|         npad, attn_mask = 0, None
 | |
|         x, npad = expand_num_tokens_to_mult8(x)
 | |
|         attn_mask = build_encoder_attention_mask(x, ar, ntok, num_chunks, 1)
 | |
|         x = x.view(bsz * num_concurrent_media, -1, dim)
 | |
|         x, int_x = self.transformer(x, return_intermediate=self.return_intermediate, mask=attn_mask)
 | |
| 
 | |
|         x = self.ln_post(x)
 | |
|         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
 | |
|         x = self.post_tile_pos_embed(x, ar)
 | |
|         x = x.reshape(bsz * num_concurrent_media, num_chunks * (ntok + npad), dim)
 | |
|         x = self.global_transformer(x, mask=attn_mask)
 | |
|         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
 | |
|         x = contract_num_tokens_from_mult8(x, npad)
 | |
| 
 | |
|         # adding back intermediate layer outputs
 | |
|         x = x.reshape(bsz, num_concurrent_media, num_chunks, ntok, dim)
 | |
|         int_x = int_x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, -1)
 | |
|         int_x = contract_num_tokens_from_mult8(int_x, npad)
 | |
|         int_x = int_x.reshape(bsz, num_concurrent_media, num_chunks, ntok, -1)
 | |
|         x = torch.cat([x, int_x], dim=-1)
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class Attention(nn.Module):
 | |
|     """Multi-head attention module."""
 | |
| 
 | |
|     def __init__(self, args: ModelArgs):
 | |
|         """
 | |
|         Initialize the Attention module.
 | |
|         Args:
 | |
|             args (ModelArgs): Model configuration parameters.
 | |
|         Attributes:
 | |
|             n_kv_heads (int): Number of key and value heads.
 | |
|             n_local_heads (int): Number of local query heads.
 | |
|             n_local_kv_heads (int): Number of local key and value heads.
 | |
|             n_rep (int): Number of repetitions for local heads.
 | |
|             head_dim (int): Dimension size of each attention head.
 | |
|             wq (ColumnParallelLinear): Linear transformation for queries.
 | |
|             wk (ColumnParallelLinear): Linear transformation for keys.
 | |
|             wv (ColumnParallelLinear): Linear transformation for values.
 | |
|             wo (RowParallelLinear): Linear transformation for output.
 | |
|             cache_k (torch.Tensor): Cached keys for attention.
 | |
|             cache_v (torch.Tensor): Cached values for attention.
 | |
|         """
 | |
|         super().__init__()
 | |
|         world_size = fs_init.get_model_parallel_world_size()
 | |
|         replication_factor = 1
 | |
|         if world_size > 8:
 | |
|             replication_factor = world_size // MP_SCALE
 | |
| 
 | |
|         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
 | |
|         self.n_kv_heads *= replication_factor
 | |
| 
 | |
|         self.n_local_heads = args.n_heads // world_size
 | |
|         self.n_local_kv_heads = self.n_kv_heads // world_size
 | |
|         self.n_rep = self.n_local_heads // self.n_local_kv_heads
 | |
|         self.head_dim = args.dim // args.n_heads
 | |
|         self.max_seq_len = args.max_seq_len
 | |
| 
 | |
|         self.wq = ColumnParallelLinear(
 | |
|             args.dim,
 | |
|             args.n_heads * self.head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.wk = ColumnParallelLinear(
 | |
|             args.dim,
 | |
|             self.n_kv_heads * self.head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.wv = ColumnParallelLinear(
 | |
|             args.dim,
 | |
|             self.n_kv_heads * self.head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.wo = RowParallelLinear(
 | |
|             args.n_heads * self.head_dim,
 | |
|             args.dim,
 | |
|             bias=False,
 | |
|             input_is_parallel=True,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.n_heads = args.n_heads
 | |
| 
 | |
|     def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
 | |
|         cache_shape = (
 | |
|             max_batch_size,
 | |
|             self.max_seq_len,
 | |
|             self.n_local_kv_heads,
 | |
|             self.head_dim,
 | |
|         )
 | |
|         self.register_buffer(
 | |
|             "key_cache",
 | |
|             torch.zeros(
 | |
|                 cache_shape,
 | |
|                 dtype=dtype,
 | |
|             ),
 | |
|             persistent=False,
 | |
|         )
 | |
|         self.register_buffer(
 | |
|             "value_cache",
 | |
|             torch.zeros(
 | |
|                 cache_shape,
 | |
|                 dtype=dtype,
 | |
|             ),
 | |
|             persistent=False,
 | |
|         )
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         mask: torch.Tensor,
 | |
|         freqs_cis: torch.Tensor,
 | |
|         position_ids: torch.LongTensor,
 | |
|     ):
 | |
|         self.key_cache = self.key_cache.to(x.device)
 | |
|         self.value_cache = self.value_cache.to(x.device)
 | |
| 
 | |
|         xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
 | |
| 
 | |
|         bs, slen, _ = xq.shape
 | |
| 
 | |
|         xq = xq.view(bs, slen, self.n_local_heads, self.head_dim)
 | |
|         xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim)
 | |
|         xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim)
 | |
| 
 | |
|         xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
 | |
| 
 | |
|         self.key_cache[:bs, position_ids, ...] = xk
 | |
|         self.value_cache[:bs, position_ids, ...] = xv
 | |
| 
 | |
|         # TODO: we can avoid slicing on first dimension by always padding to max_batch_size()
 | |
|         xk = self.key_cache[:bs, ...]
 | |
|         xv = self.value_cache[:bs, ...]
 | |
| 
 | |
|         xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)]
 | |
| 
 | |
|         xk = xk.repeat_interleave(self.n_rep, dim=1)
 | |
|         xv = xv.repeat_interleave(self.n_rep, dim=1)
 | |
| 
 | |
|         attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
 | |
| 
 | |
|         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
 | |
| 
 | |
|         out = F.linear(attn_output, self.wo.weight)
 | |
|         out = reduce_from_tensor_model_parallel_region(out)
 | |
|         return out
 | |
| 
 | |
| 
 | |
| class FeedForward(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         dim: int,
 | |
|         hidden_dim: int,
 | |
|         multiple_of: int,
 | |
|         ffn_dim_multiplier: float | None,
 | |
|     ):
 | |
|         """
 | |
|         Initialize the FeedForward module.
 | |
|         Args:
 | |
|             dim (int): Input dimension.
 | |
|             hidden_dim (int): Hidden dimension of the feedforward layer.
 | |
|             multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
 | |
|             ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
 | |
|         Attributes:
 | |
|             w1 (ColumnParallelLinear): Linear transformation for the first layer.
 | |
|             w2 (RowParallelLinear): Linear transformation for the second layer.
 | |
|             w3 (ColumnParallelLinear): Linear transformation for the third layer.
 | |
|         """
 | |
|         super().__init__()
 | |
|         hidden_dim = int(2 * hidden_dim / 3)
 | |
|         # custom dim factor multiplier
 | |
|         if ffn_dim_multiplier is not None:
 | |
|             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
 | |
|         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 | |
| 
 | |
|         self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
 | |
|         self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
 | |
|         self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
 | |
| 
 | |
|     def forward(self, x):
 | |
|         x1, x3 = [F.linear(x, w) for w in [self.w1.weight, self.w3.weight]]
 | |
|         x1 = F.silu(x1)
 | |
|         x_in = x1 * x3
 | |
|         out = F.linear(x_in, self.w2.weight)
 | |
|         out = reduce_from_tensor_model_parallel_region(out)
 | |
|         return out
 | |
| 
 | |
| 
 | |
| class TransformerBlock(nn.Module):
 | |
|     def __init__(self, layer_id: int, args: ModelArgs):
 | |
|         """
 | |
|         Initialize a TransformerBlock.
 | |
|         Args:
 | |
|             layer_id (int): Identifier for the layer.
 | |
|             args (ModelArgs): Model configuration parameters.
 | |
|         Attributes:
 | |
|             n_heads (int): Number of attention heads.
 | |
|             dim (int): Dimension size of the model.
 | |
|             head_dim (int): Dimension size of each attention head.
 | |
|             attention (Attention): Attention module.
 | |
|             feed_forward (FeedForward): FeedForward module.
 | |
|             layer_id (int): Identifier for the layer.
 | |
|             attention_norm (RMSNorm): Layer normalization for attention output.
 | |
|             ffn_norm (RMSNorm): Layer normalization for feedforward output.
 | |
|         """
 | |
|         super().__init__()
 | |
|         self.n_heads = args.n_heads
 | |
|         self.dim = args.dim
 | |
|         self.head_dim = args.dim // args.n_heads
 | |
|         self.attention = Attention(args)
 | |
|         self.feed_forward = FeedForward(
 | |
|             dim=args.dim,
 | |
|             hidden_dim=4 * args.dim,
 | |
|             multiple_of=args.multiple_of,
 | |
|             ffn_dim_multiplier=args.ffn_dim_multiplier,
 | |
|         )
 | |
|         self.layer_id = layer_id
 | |
|         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
 | |
|         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 | |
| 
 | |
|     def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
 | |
|         self.attention.setup_cache(max_batch_size, dtype)
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         freqs_cis: torch.Tensor,
 | |
|         mask: torch.Tensor,
 | |
|         position_ids: torch.LongTensor,
 | |
|     ) -> torch.Tensor:
 | |
|         """
 | |
|         Perform a forward pass through the TransformerBlock.
 | |
|         Args:
 | |
|             x (torch.Tensor): Input tensor.
 | |
|             start_pos (int): Starting position for attention caching.
 | |
|             freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
 | |
|             mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
 | |
|         Returns:
 | |
|             torch.Tensor: Output tensor after applying attention and feedforward layers.
 | |
|         """
 | |
|         h = self.attention.forward(
 | |
|             x=self.attention_norm(x),
 | |
|             freqs_cis=freqs_cis,
 | |
|             mask=mask,
 | |
|             position_ids=position_ids,
 | |
|         )
 | |
|         h = h + x
 | |
|         out = h + self.feed_forward.forward(self.ffn_norm(h))
 | |
|         return out
 | |
| 
 | |
| 
 | |
| class TilePositionEmbedding(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         num_tiles: int,
 | |
|         width: int,
 | |
|         gated: bool = False,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.num_tiles = num_tiles
 | |
|         self.width = width
 | |
|         self.embedding = nn.Parameter(torch.randn(num_tiles, num_tiles, 1, width) / math.sqrt(width))
 | |
|         self.gated = gated
 | |
|         if gated:
 | |
|             self.gate = nn.Parameter(torch.zeros(1))
 | |
| 
 | |
|         self._register_load_state_dict_pre_hook(self.load_hook)
 | |
| 
 | |
|     def load_hook(
 | |
|         self,
 | |
|         state_dict,
 | |
|         prefix,
 | |
|         local_metadata,
 | |
|         strict,
 | |
|         missing_keys,
 | |
|         unexpected_keys,
 | |
|         error_msgs,
 | |
|     ):
 | |
|         # load the weights from the checkpoint
 | |
|         embed = state_dict.get(prefix + "embedding")
 | |
|         if embed is not None:
 | |
|             # reshape the weights to the correct shape
 | |
|             nt_old, nt_old, _, w = embed.shape
 | |
|             logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
 | |
|             embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
 | |
|             # assign the weights to the module
 | |
|             state_dict[prefix + "embedding"] = embed_new
 | |
| 
 | |
|     @staticmethod
 | |
|     def _dynamic_resize(embed: torch.Tensor, num_tiles: int):
 | |
|         nt_old, nt_old, _, w = embed.shape
 | |
|         embed = embed.permute(2, 3, 0, 1)
 | |
| 
 | |
|         embed_new = F.interpolate(
 | |
|             embed,
 | |
|             size=(num_tiles, num_tiles),
 | |
|             mode="bilinear",
 | |
|             align_corners=True,
 | |
|         )
 | |
|         # reshape the weights to the correct shape
 | |
|         embed_new = embed_new.permute(2, 3, 0, 1)
 | |
|         return embed_new
 | |
| 
 | |
|     def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None):
 | |
|         embed = self.embedding
 | |
|         if num_tiles is None:
 | |
|             num_tiles = self.num_tiles
 | |
|         elif num_tiles > self.num_tiles:
 | |
|             embed = TilePositionEmbedding._dynamic_resize(self.embedding, num_tiles)
 | |
|         out_pos_embed = torch.zeros(x.shape[0], num_tiles, 1, self.width, device=x.device, dtype=x.dtype)
 | |
|         for idx, arx in enumerate(ar):
 | |
|             h, w = arx
 | |
|             out_pos_embed[idx, : w * h] = embed[:h, :w].reshape(w * h, 1, self.width)
 | |
|         if self.gated:
 | |
|             out_pos_embed = out_pos_embed * self.gate.tanh()
 | |
|         x = x + out_pos_embed
 | |
|         return x
 | |
| 
 | |
| 
 | |
| def _noinit(x):
 | |
|     return x
 | |
| 
 | |
| 
 | |
| class CrossAttention(torch.nn.Module):
 | |
|     """Cross attention layer with model-parallel attention layers."""
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         dim: int,
 | |
|         head_dim: int,
 | |
|         n_heads: int,
 | |
|         n_kv_heads: int,
 | |
|         norm_eps: float,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.world_size = fs_init.get_model_parallel_world_size()
 | |
|         replication_factor = 1
 | |
|         if self.world_size > 8:
 | |
|             replication_factor = self.world_size // MP_SCALE
 | |
|         n_kv_heads *= replication_factor
 | |
| 
 | |
|         assert n_heads % n_kv_heads == 0
 | |
| 
 | |
|         self.wq = ColumnParallelLinear(
 | |
|             dim,
 | |
|             n_heads * head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=_noinit,
 | |
|         )
 | |
| 
 | |
|         self.wk = ColumnParallelLinear(
 | |
|             dim,
 | |
|             n_kv_heads * head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=_noinit,
 | |
|         )
 | |
|         self.wv = ColumnParallelLinear(
 | |
|             dim,
 | |
|             n_kv_heads * head_dim,
 | |
|             bias=False,
 | |
|             gather_output=False,
 | |
|             init_method=_noinit,
 | |
|         )
 | |
|         self.wo = RowParallelLinear(
 | |
|             n_heads * head_dim,
 | |
|             dim,
 | |
|             bias=False,
 | |
|             input_is_parallel=True,
 | |
|             init_method=_noinit,
 | |
|         )
 | |
| 
 | |
|         self.n_heads = n_heads
 | |
|         self.head_dim = head_dim
 | |
|         self.n_kv_heads = n_kv_heads
 | |
| 
 | |
|         self.q_norm = RMSNorm(
 | |
|             self.head_dim,
 | |
|             eps=norm_eps,
 | |
|         )
 | |
|         self.k_norm = RMSNorm(
 | |
|             self.head_dim,
 | |
|             eps=norm_eps,
 | |
|         )
 | |
| 
 | |
|         # cross-attention heads are model parallel similar to
 | |
|         # self-attention, and we also use the identical KV head
 | |
|         # combination to ensure parity with the corresponding
 | |
|         # trunk LLM (i.e., group query attention) -- @dubeya
 | |
|         # local heads
 | |
|         assert self.n_heads % self.n_kv_heads == 0
 | |
|         assert self.n_heads % self.world_size == 0
 | |
|         assert self.n_kv_heads % self.world_size == 0
 | |
|         self.n_local_heads = self.n_heads // self.world_size
 | |
|         self.n_local_kv_heads = self.n_kv_heads // self.world_size
 | |
|         self.n_rep = self.n_local_heads // self.n_local_kv_heads
 | |
| 
 | |
|     def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
 | |
|         bsz = xattn_tokens.shape[0]
 | |
|         xk = self.wk(xattn_tokens)
 | |
|         xv = self.wv(xattn_tokens)
 | |
| 
 | |
|         _, seqlen_y, _ = xk.shape
 | |
| 
 | |
|         xk = xk.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
 | |
|         xv = xv.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
 | |
| 
 | |
|         xk, xv = [tensor.transpose(1, 2) for tensor in (xk, xv)]
 | |
| 
 | |
|         # repeat k/v heads if n_kv_heads < n_heads
 | |
|         xk = xk.repeat_interleave(self.n_rep, dim=1)
 | |
|         xv = xv.repeat_interleave(self.n_rep, dim=1)
 | |
| 
 | |
|         xk = self.k_norm(xk)
 | |
| 
 | |
|         return torch.stack([xk, xv])
 | |
| 
 | |
|     def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
 | |
|         return self._compute_xattn_kv_cache(xattn_tokens)
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         xattn_mask: torch.Tensor,
 | |
|         full_text_row_masked_out_mask: torch.Tensor,
 | |
|         xattn_cache: torch.Tensor,
 | |
|     ) -> torch.Tensor:
 | |
|         xq = F.linear(x, self.wq.weight)
 | |
|         bsz, seqlen, _ = x.shape
 | |
| 
 | |
|         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
 | |
|         xq = self.q_norm(xq)
 | |
|         xq = xq.transpose(1, 2)
 | |
| 
 | |
|         xk, xv = xattn_cache
 | |
| 
 | |
|         output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=xattn_mask, dropout_p=0.0)
 | |
|         output = output * full_text_row_masked_out_mask
 | |
|         output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1)
 | |
| 
 | |
|         out = F.linear(output, self.wo.weight)
 | |
|         out = reduce_from_tensor_model_parallel_region(out)
 | |
|         return out
 | |
| 
 | |
| 
 | |
| class CrossAttentionTransformerBlock(torch.nn.Module):
 | |
|     """Cross-attention transformer block with tanh-gated attention and feedforward."""
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         args: ModelArgs,
 | |
|         layer_id: int,
 | |
|         no_ffn: bool = False,
 | |
|     ) -> None:
 | |
|         super().__init__()
 | |
|         self.layer_id = layer_id
 | |
|         self.n_heads = args.n_heads
 | |
|         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
 | |
|         self.dim = args.dim
 | |
|         self.head_dim = args.dim // args.n_heads
 | |
|         self.attention = CrossAttention(
 | |
|             dim=args.dim,
 | |
|             head_dim=self.head_dim,
 | |
|             n_heads=self.n_heads,
 | |
|             n_kv_heads=self.n_kv_heads,
 | |
|             norm_eps=args.norm_eps,
 | |
|         )
 | |
| 
 | |
|         self.attention_norm = RMSNorm(
 | |
|             args.dim,
 | |
|             eps=args.norm_eps,
 | |
|         )
 | |
|         self.gate_attn = torch.nn.Parameter(torch.zeros(1))
 | |
| 
 | |
|         self.feed_forward = FeedForward(
 | |
|             dim=args.dim,
 | |
|             hidden_dim=4 * args.dim,
 | |
|             ffn_dim_multiplier=args.ffn_dim_multiplier,
 | |
|             multiple_of=args.multiple_of,
 | |
|         )
 | |
|         self.ffn_norm = RMSNorm(
 | |
|             args.dim,
 | |
|             eps=args.norm_eps,
 | |
|         )
 | |
|         self.gate_ffwd = torch.nn.Parameter(torch.zeros(1))
 | |
| 
 | |
|         self.no_ffn = no_ffn
 | |
| 
 | |
|     def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
 | |
|         return self.attention.compute_xattn_kv_cache(xattn_tokens)
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         xattn_mask: torch.Tensor,
 | |
|         full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
 | |
|         xattn_cache: torch.Tensor,
 | |
|     ) -> torch.Tensor:
 | |
|         _attn_out = self.attention(
 | |
|             x=self.attention_norm(x),
 | |
|             xattn_mask=xattn_mask,
 | |
|             xattn_cache=xattn_cache,
 | |
|             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
 | |
|         )
 | |
|         h = x + self.gate_attn.tanh() * _attn_out
 | |
|         _ffn = self.feed_forward(self.ffn_norm(h))
 | |
|         _ffn = full_text_row_masked_out_mask[:, 0] * _ffn  # type: ignore
 | |
|         h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn)
 | |
|         return h
 | |
| 
 | |
| 
 | |
| class DummyCrossAttentionTransformerBlock:
 | |
|     """Dummy cross-attention transformer block with tanh-gated attention and feedforward."""
 | |
| 
 | |
|     def __call__(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         *args,
 | |
|         **kwargs,
 | |
|     ) -> torch.Tensor:
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class DummySelfAttentionTransformerBlock:
 | |
|     """Dummy self-attention transformer block"""
 | |
| 
 | |
|     def __call__(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         *args,
 | |
|         **kwargs,
 | |
|     ) -> torch.Tensor:
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class CrossAttentionTransformerVision(torch.nn.Module):
 | |
|     def __init__(self, args: ModelArgs) -> None:
 | |
|         super().__init__()
 | |
|         return_intermediate = "3,7,15,23,30"
 | |
|         self.vision_input_dim = 1280
 | |
|         self.image_res = args.vision_chunk_size
 | |
|         self.max_num_chunks = args.vision_max_num_chunks
 | |
|         if return_intermediate is not None:
 | |
|             return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
 | |
|             self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
 | |
|         self.patch_size = 14
 | |
|         self.vision_encoder = VisionEncoder(
 | |
|             max_num_tiles=4,
 | |
|             image_size=args.vision_chunk_size,
 | |
|             patch_size=self.patch_size,
 | |
|             n_global_layers=8,
 | |
|             global_model=True,
 | |
|             return_intermediate=return_intermediate,
 | |
|         )
 | |
|         # vision token projection
 | |
|         self.vision_projection = ColumnParallelLinear(
 | |
|             self.vision_input_dim,
 | |
|             args.dim,
 | |
|             bias=True,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
| 
 | |
|     def forward(self, images: torch.Tensor, aspect_ratios: torch.Tensor) -> torch.Tensor:
 | |
|         # vision_tokens: (B, T, D)
 | |
|         # aspect_ratios: (B, T)
 | |
|         # h: (B, T, D)
 | |
|         vision_tokens = self.vision_encoder(images.to(dtype=torch.get_default_dtype()), aspect_ratios)
 | |
| 
 | |
|         vision_tokens = F.linear(vision_tokens, self.vision_projection.weight, self.vision_projection.bias)
 | |
|         vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens)
 | |
|         return vision_tokens
 | |
| 
 | |
| 
 | |
| class CrossAttentionTransformerText(torch.nn.Module):
 | |
|     INFERENCE_IMAGE_TOKEN_ID = 128010
 | |
| 
 | |
|     def __init__(self, args: ModelArgs) -> None:
 | |
|         super().__init__()
 | |
|         self.world_size = fs_init.get_model_parallel_world_size()
 | |
|         assert args.vocab_size > 0
 | |
|         self.vocab_size = args.vocab_size
 | |
|         self.n_layers = args.n_layers
 | |
|         self.dim = args.dim
 | |
|         self.head_dim = args.dim // args.n_heads
 | |
|         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
 | |
|         self.n_local_kv_heads = self.n_kv_heads // self.world_size
 | |
|         assert self.vocab_size % self.world_size == 0
 | |
|         self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
 | |
|         self.pos_embeddings = None
 | |
|         # final norm layer (not necessary for post-norm)
 | |
|         self.norm = RMSNorm(args.dim, eps=args.norm_eps)
 | |
| 
 | |
|         # output layer
 | |
|         self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
 | |
| 
 | |
|         self.n_llama_layers = args.n_layers
 | |
|         self.model_dim = args.dim
 | |
| 
 | |
|         # BLOCKS
 | |
| 
 | |
|         self.fusion_schedule = self._init_fusion_schedule(args.vision_num_cross_attention_layers)
 | |
|         self.learnable_embedding = VocabParallelEmbedding(
 | |
|             max(fs_init.get_model_parallel_world_size(), 8),
 | |
|             args.dim,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.num_frozen_embeddings = self.tok_embeddings.num_embeddings
 | |
|         self._thresh = self.num_frozen_embeddings - 1
 | |
| 
 | |
|         # transformer blocks
 | |
|         self.layers = torch.nn.ModuleList()
 | |
|         self.cross_attention_layers = torch.nn.ModuleList()
 | |
|         for i in range(args.n_layers):
 | |
|             layer_id = i
 | |
|             block = TransformerBlock(args=args, layer_id=layer_id)
 | |
|             self.layers.append(block)
 | |
|             if layer_id in self.fusion_schedule:
 | |
|                 xa_layer_id = self.fusion_schedule.index(layer_id) + args.n_layers
 | |
|                 block = CrossAttentionTransformerBlock(
 | |
|                     args,
 | |
|                     layer_id=xa_layer_id,
 | |
|                 )
 | |
|                 self.cross_attention_layers.append(block)
 | |
| 
 | |
|         # add xattn and dummy layers to avoid conditionals in forward()
 | |
|         self.text_and_xattn_layers = []
 | |
| 
 | |
|         for idx, layer in enumerate(self.layers):
 | |
|             if idx in self.fusion_schedule:
 | |
|                 xattn_layer_idx = self.fusion_schedule.index(idx)
 | |
|                 xattn_layer = self.cross_attention_layers[xattn_layer_idx]
 | |
|             else:
 | |
|                 xattn_layer_idx = 0
 | |
|                 xattn_layer = DummyCrossAttentionTransformerBlock()
 | |
| 
 | |
|             self.text_and_xattn_layers.append(
 | |
|                 (
 | |
|                     layer,
 | |
|                     xattn_layer,
 | |
|                     xattn_layer_idx,
 | |
|                 )
 | |
|             )
 | |
|         self.freqs_cis = precompute_freqs_cis(
 | |
|             args.dim // args.n_heads,
 | |
|             args.max_seq_len * 2,
 | |
|             args.rope_theta,
 | |
|             args.use_scaled_rope,
 | |
|         )
 | |
| 
 | |
|         self.args = args
 | |
|         self.cache_is_setup = False
 | |
|         self.max_seq_len = args.max_seq_len
 | |
| 
 | |
|     def _init_fusion_schedule(
 | |
|         self,
 | |
|         num_layers: int,
 | |
|     ) -> list[int]:
 | |
|         llama_layers = list(range(self.n_llama_layers))
 | |
| 
 | |
|         # uniformly spread the layers
 | |
|         k = math.ceil(len(llama_layers) / num_layers)
 | |
|         return llama_layers[::-1][::k][:num_layers][::-1]
 | |
| 
 | |
|     def get_partially_trainable_embedding(self, x):
 | |
|         xz = torch.zeros_like(x, device=x.device)
 | |
|         oz = torch.ones_like(x, device=x.device)
 | |
|         x_orig = torch.minimum(x, torch.tensor(self._thresh, device=x.device))
 | |
|         x_new = torch.maximum(x, torch.tensor(self._thresh + 1, device=x.device)) - self.num_frozen_embeddings
 | |
| 
 | |
|         mask_orig = torch.where(x >= self.num_frozen_embeddings, xz, oz).unsqueeze(-1)
 | |
|         mask_new = torch.where(x < self.num_frozen_embeddings, xz, oz).unsqueeze(-1)
 | |
| 
 | |
|         x_orig = self.tok_embeddings(x_orig)
 | |
|         x_new = self.learnable_embedding(x_new).type_as(x_orig)
 | |
|         return x_orig * mask_orig.type_as(x_orig) + x_new * mask_new.type_as(x_new)
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         position_ids: torch.LongTensor,
 | |
|         h: torch.Tensor,
 | |
|         xattn_mask: torch.Tensor,
 | |
|         full_text_row_masked_out_mask: torch.Tensor,
 | |
|         xattn_caches: torch.Tensor,
 | |
|         text_only_inference: bool = False,
 | |
|     ):
 | |
|         assert self.cache_is_setup, "Please set up cache before calling forward"
 | |
|         self.mask_cache = self.mask_cache.to(h.device)
 | |
|         self.freqs_cis = self.freqs_cis.to(h.device)
 | |
|         mask = self.mask_cache.index_select(2, position_ids)
 | |
|         freqs_cis = self.freqs_cis.index_select(0, position_ids)
 | |
| 
 | |
|         for (
 | |
|             layer,
 | |
|             xattn_layer,
 | |
|             xattn_layer_idx,
 | |
|         ) in self.text_and_xattn_layers:
 | |
|             if not text_only_inference:
 | |
|                 h = xattn_layer(
 | |
|                     x=h,
 | |
|                     xattn_mask=xattn_mask,
 | |
|                     xattn_cache=xattn_caches[xattn_layer_idx],
 | |
|                     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
 | |
|                 )
 | |
|             h = layer(
 | |
|                 x=h,
 | |
|                 mask=mask,
 | |
|                 freqs_cis=freqs_cis,
 | |
|                 position_ids=position_ids,
 | |
|             )
 | |
| 
 | |
|         h = self.norm(h)
 | |
| 
 | |
|         output = F.linear(h, self.output.weight)
 | |
|         output = gather_from_tensor_model_parallel_region(output)
 | |
|         return output.float()
 | |
| 
 | |
|     def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
 | |
|         # Set up the text kv caches
 | |
|         ones = torch.ones(
 | |
|             (self.max_seq_len, self.max_seq_len),
 | |
|             dtype=torch.bool,
 | |
|             device=device,
 | |
|         )
 | |
|         self.register_buffer(
 | |
|             "mask_cache",
 | |
|             torch.tril(
 | |
|                 ones,
 | |
|             )
 | |
|             .unsqueeze(0)
 | |
|             .unsqueeze(0),
 | |
|             persistent=False,
 | |
|         )
 | |
|         for layer in self.layers:
 | |
|             layer.setup_cache(max_batch_size, dtype=dtype)
 | |
|         self.cache_is_setup = True
 | |
| 
 | |
|     def _get_xattn_mask(
 | |
|         self,
 | |
|         num_tokens,
 | |
|         text_device,
 | |
|         text_dtype,
 | |
|         vision_tokens,
 | |
|         cross_attention_masks,
 | |
|     ) -> tuple[Tensor, Tensor]:
 | |
|         assert vision_tokens is not None, "Vision tokens must be provided"
 | |
|         vision_seqlen = vision_tokens.shape[3]
 | |
|         assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
 | |
|             f"Mismatch in number of images given and number of masks given {vision_tokens.shape} {cross_attention_masks.shape}"
 | |
|         )
 | |
|         assert vision_tokens.shape[2] == cross_attention_masks.shape[3], (
 | |
|             f"Vision tokens shape {vision_tokens.shape} mismatch with xattn shape {cross_attention_masks.shape}"
 | |
|         )
 | |
|         assert num_tokens == cross_attention_masks.shape[1], (
 | |
|             f"Mismatch in text sequence length and cross attention mask sequence length {num_tokens} {cross_attention_masks.shape}"
 | |
|         )
 | |
|         _, _, _, num_image_tokens, image_token_dim = tuple(vision_tokens.shape)
 | |
|         bsz, ntext, nimg, nchunks = cross_attention_masks.shape
 | |
|         cross_attention_masks = (
 | |
|             cross_attention_masks.repeat_interleave(vision_seqlen, dim=3).view(bsz, ntext, -1).unsqueeze(1)
 | |
|         )
 | |
|         full_text_row_masked_out_mask = _get_full_row_masked_out_mask(
 | |
|             cross_attention_masks,
 | |
|             get_negative_inf_value(cross_attention_masks.dtype),
 | |
|         )
 | |
|         cross_attention_masks *= full_text_row_masked_out_mask
 | |
| 
 | |
|         return (
 | |
|             cross_attention_masks.to(device=text_device, dtype=text_dtype),
 | |
|             full_text_row_masked_out_mask.to(device=text_device),
 | |
|         )
 | |
| 
 | |
| 
 | |
| class CrossAttentionTransformer(torch.nn.Module):
 | |
|     def __init__(self, args: ModelArgs) -> None:
 | |
|         super().__init__()
 | |
|         self.params = args
 | |
| 
 | |
|         self.model_dim = args.dim
 | |
|         self.vision_model = CrossAttentionTransformerVision(args)
 | |
|         self.text_model = CrossAttentionTransformerText(args)
 | |
|         self.image_res = args.vision_chunk_size
 | |
|         self.max_num_chunks = args.vision_max_num_chunks
 | |
|         self.image_transform = partial(
 | |
|             VariableSizeImageTransform(size=args.vision_chunk_size),
 | |
|             max_num_chunks=args.vision_max_num_chunks,
 | |
|         )
 | |
| 
 | |
|     def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
 | |
|         self.text_model.setup_cache(max_batch_size, device, dtype)
 | |
| 
 | |
|     def compute_vision_tokens_masks(
 | |
|         self,
 | |
|         batch_images: list[list[PIL_Image.Image]],
 | |
|         batch_masks: list[list[list[int]]],
 | |
|         total_len: int,
 | |
|         device: torch.device,
 | |
|     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 | |
|         skip_vision_encoder = False
 | |
| 
 | |
|         assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"
 | |
| 
 | |
|         max_num_images = max(len(x) for x in batch_images)
 | |
|         bsz = len(batch_images)
 | |
| 
 | |
|         if max_num_images == 0:
 | |
|             num_chunks = [[self.max_num_chunks] for _ in batch_images]
 | |
|             skip_vision_encoder = True
 | |
|         else:
 | |
|             images_and_aspect_ratios = [[self.image_transform(im) for im in row] for row in batch_images]
 | |
|             transformed_images = [[x[0] for x in row] for row in images_and_aspect_ratios]
 | |
| 
 | |
|             aspect_ratios = torch.ones(bsz, max_num_images, 2, dtype=torch.int64)
 | |
|             for i, row in enumerate(images_and_aspect_ratios):
 | |
|                 if len(row) > 0:
 | |
|                     aspect_ratios[i, : len(row)] = torch.stack([torch.tensor(x[1]) for x in row])
 | |
| 
 | |
|             stacked_images, num_chunks = _stack_images(
 | |
|                 transformed_images,
 | |
|                 max_num_chunks=self.max_num_chunks,
 | |
|                 image_res=self.params.vision_chunk_size,
 | |
|                 max_num_images=max_num_images,
 | |
|             )
 | |
|             stacked_images = stacked_images.to(device=device)
 | |
| 
 | |
|         if skip_vision_encoder:
 | |
|             vision_tokens = torch.zeros(
 | |
|                 (
 | |
|                     bsz,
 | |
|                     max_num_images,
 | |
|                     self.max_num_chunks,
 | |
|                     int((self.vision_model.image_res / self.vision_model.patch_size) ** 2 + 1),
 | |
|                     self.model_dim,
 | |
|                 ),
 | |
|             )
 | |
|         else:
 | |
|             vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
 | |
| 
 | |
|         bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
 | |
|         xattn_caches = torch.stack(
 | |
|             [
 | |
|                 layer.compute_xattn_kv_cache(vision_tokens.view(bsz, -1, image_token_dim))
 | |
|                 for layer in self.text_model.cross_attention_layers
 | |
|             ]
 | |
|         )
 | |
|         padded_masks = _pad_masks(
 | |
|             batch_masks,
 | |
|             num_chunks,
 | |
|             total_len,
 | |
|             self.max_num_chunks,
 | |
|         )
 | |
| 
 | |
|         cross_attention_masks, full_text_row_masked_out_mask = self.text_model._get_xattn_mask(
 | |
|             num_tokens=total_len,
 | |
|             text_device=vision_tokens.device.type,
 | |
|             text_dtype=next(self.text_model.parameters()).dtype,
 | |
|             vision_tokens=vision_tokens,
 | |
|             cross_attention_masks=padded_masks,
 | |
|         )
 | |
| 
 | |
|         return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask)
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         position_ids: torch.Tensor,
 | |
|         tokens: torch.Tensor,
 | |
|         cross_attention_masks: torch.Tensor,
 | |
|         full_text_row_masked_out_mask: torch.Tensor,
 | |
|         xattn_caches: torch.Tensor,
 | |
|         text_only_inference: bool = False,
 | |
|     ) -> torch.Tensor:
 | |
|         h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids])
 | |
|         logits = self.text_model.forward(
 | |
|             position_ids=position_ids,
 | |
|             h=h,
 | |
|             xattn_mask=cross_attention_masks[:, :, position_ids],
 | |
|             full_text_row_masked_out_mask=full_text_row_masked_out_mask[:, :, position_ids],
 | |
|             xattn_caches=xattn_caches,
 | |
|             text_only_inference=text_only_inference,
 | |
|         )
 | |
|         return logits
 | |
| 
 | |
| 
 | |
| def _stack_images(
 | |
|     images: list[list[PIL_Image.Image]],
 | |
|     max_num_chunks: int,
 | |
|     image_res: int,
 | |
|     max_num_images: int,
 | |
| ) -> tuple[torch.Tensor, list[int]]:
 | |
|     """
 | |
|     Takes a list of list of images and stacks them into a tensor.
 | |
|     This function is needed since images can be of completely
 | |
|     different resolutions and aspect ratios.
 | |
|     """
 | |
|     out_images, out_num_chunks = [], []
 | |
|     for imgs_sample in images:
 | |
|         out_images_i = torch.zeros(
 | |
|             max_num_images,
 | |
|             max_num_chunks,
 | |
|             3,
 | |
|             image_res,
 | |
|             image_res,
 | |
|         )
 | |
|         _num_chunks = []
 | |
|         for j, chunks_image in enumerate(imgs_sample):
 | |
|             out_images_i[j, : chunks_image.shape[0]] = chunks_image
 | |
|             _num_chunks.append(chunks_image.shape[0])
 | |
|         out_images.append(out_images_i)
 | |
|         out_num_chunks.append(_num_chunks)
 | |
|     return torch.stack(out_images), out_num_chunks
 | |
| 
 | |
| 
 | |
| def _pad_masks(
 | |
|     all_masks: list[list[list[int]]],
 | |
|     all_num_chunks: list[list[int]],
 | |
|     total_len: int,
 | |
|     max_num_chunks: int,
 | |
| ) -> torch.Tensor:
 | |
|     dtype = torch.get_default_dtype()
 | |
|     inf_value = get_negative_inf_value(dtype)
 | |
| 
 | |
|     bsz = len(all_masks)
 | |
|     max_num_media = max([len(m) for m in all_masks])
 | |
| 
 | |
|     out_masks = torch.full(
 | |
|         (bsz, total_len, max_num_media, max_num_chunks),
 | |
|         inf_value,
 | |
|         dtype=dtype,
 | |
|     )
 | |
| 
 | |
|     for idx, (mask, num_chunks) in enumerate(zip(all_masks, all_num_chunks, strict=False)):
 | |
|         for mask_idx, (mask_elem, mask_num_chunks) in enumerate(zip(mask, num_chunks, strict=False)):
 | |
|             if len(mask_elem) == 2:
 | |
|                 mask_elem[1] = min(mask_elem[1], total_len)
 | |
|                 if mask_elem[1] == -1:
 | |
|                     mask_elem[1] = total_len
 | |
|                 out_masks[idx, mask_elem[0] : mask_elem[1], mask_idx, :mask_num_chunks].fill_(0.0)
 | |
| 
 | |
|     return out_masks
 |