llama-stack-mirror/llama_stack/models/llama/llama3/multimodal/encoder_utils.py

179 lines
6.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger
import torch
import torch.nn.functional as F
from .utils import get_negative_inf_value, to_2tuple
logger = getLogger()
def resize_local_position_embedding(orig_pos_embed, grid_size):
    """
    Interpolate a local (per-tile) position embedding to a new patch grid.

    The first row of ``orig_pos_embed`` is carried over unchanged; the remaining
    rows are laid out as a square grid, bilinearly resampled to ``grid_size``
    and flattened again.

    Input:   [n * n + 1, dim]  (the side ``n`` is derived as sqrt(len - 1))
    Returns: [grid_size[0] * grid_size[1] + 1, dim]
    """
    new_grid_size = to_2tuple(grid_size)
    # The source grid side is recovered from the row count, so the rows after
    # the first one are expected to form a perfect square.
    orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))

    leading_row = orig_pos_embed[:1]
    patch_rows = orig_pos_embed[1:]

    logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")

    # F.interpolate expects (batch, channels, height, width).
    patch_image = patch_rows.reshape(1, orig_grid_size[0], orig_grid_size[1], -1)
    patch_image = patch_image.permute(0, 3, 1, 2)
    patch_image = F.interpolate(
        patch_image,
        size=new_grid_size,
        mode="bilinear",
        align_corners=True,
    )

    # Back to channels-last, then flatten the grid into a token axis.
    n_new_tokens = new_grid_size[0] * new_grid_size[1]
    resized_rows = patch_image.permute(0, 2, 3, 1).reshape(1, n_new_tokens, -1)[0]

    return torch.cat([leading_row, resized_rows], dim=0)
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
    """
    Build a global (multi-tile) position embedding from a local one.

    The local patch grid is bilinearly upsampled to cover ``x_scale`` by
    ``y_scale`` tiles, then cut into tiles of the original grid size. The
    first row of the input is replicated into slot 0 of every tile.

    Input:   local embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
    Returns: global embedding of shape
             [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
    ``x_scale`` and ``y_scale`` are the number of tiles along each axis.
    """
    patch_rows = pos_and_cls_embed[1:]
    cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)

    tile_grid = to_2tuple(grid_size)
    tile_h = tile_grid[0]
    tile_w = tile_grid[1]

    # F.interpolate expects (batch, channels, height, width).
    canvas = patch_rows.reshape(1, tile_h, tile_w, -1).permute(0, 3, 1, 2)
    canvas = F.interpolate(
        canvas,
        size=(x_scale * tile_h, y_scale * tile_w),
        mode="bilinear",
        align_corners=True,
    )

    # Channels-last again, split each spatial axis into (tile index, offset
    # within tile), then bring the two tile indices to the front.
    tiles = canvas.permute(0, 2, 3, 1).view(x_scale, tile_h, y_scale, tile_w, -1)
    tiles = tiles.permute(0, 2, 1, 3, 4).contiguous()
    tiles = tiles.reshape(x_scale, y_scale, tile_h * tile_w, -1)

    # One copy of the first-row embedding per tile, placed ahead of its patches.
    cls_per_tile = cls_embed.expand(x_scale, y_scale, -1, -1)
    return torch.cat([cls_per_tile, tiles], dim=2)
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
    """
    Resample a global (multi-tile) position embedding to a new tiling / grid.

    Input:   [x_old, y_old, old_grid * old_grid + 1, dim]
    Returns: [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
    ``x_scale`` and ``y_scale`` are the number of tiles along each axis.

    ``grid_size`` is indexed with ``[0]`` and ``[1]``, so it must be a pair.
    The old per-tile grid side is derived as sqrt of the patch-token count.
    """
    # Slot 0 on the token axis is handled separately from the patch slots.
    cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
    patch_embed = pos_and_cls_embed[:, :, 1:]

    xs_old, ys_old, ntok, dim = patch_embed.shape
    old_side = int(math.sqrt(ntok))
    new_h = grid_size[0]
    new_w = grid_size[1]

    # Stitch the tiles into one big channels-first image for F.interpolate:
    # (x, y, h, w, dim) -> (x, h, y, w, dim) -> (x*h, y*w, dim) -> (1, dim, H, W)
    canvas = patch_embed.view(xs_old, ys_old, old_side, old_side, dim)
    canvas = canvas.permute(0, 2, 1, 3, 4).contiguous()
    canvas = canvas.view(xs_old * old_side, ys_old * old_side, dim)
    canvas = canvas.unsqueeze(0).permute(0, 3, 1, 2)

    canvas = F.interpolate(
        canvas,
        size=(new_h * x_scale, new_w * y_scale),
        mode="bilinear",
        align_corners=True,
    )

    # Undo the stitching: back to channels-last, split each spatial axis into
    # (tile index, offset within tile) and flatten the within-tile grid.
    tiles = canvas.permute(0, 2, 3, 1)[0]
    tiles = tiles.view(x_scale, new_h, y_scale, new_w, dim)
    tiles = tiles.permute(0, 2, 1, 3, 4).contiguous()
    tiles = tiles.view(x_scale, y_scale, new_h * new_w, dim)

    # The slot-0 embeddings form an (x_old, y_old) grid of their own; resample
    # that grid to (x_scale, y_scale) the same way.
    cls_canvas = F.interpolate(
        cls_embed.permute(2, 3, 0, 1),
        size=(x_scale, y_scale),
        mode="bilinear",
        align_corners=True,
    )
    cls_tiles = cls_canvas.permute(2, 3, 0, 1)

    # Slot 0 goes back in front of the patch slots.
    return torch.cat([cls_tiles, tiles], dim=2)
def build_encoder_attention_mask(
    x: torch.Tensor,
    ar: torch.Tensor,
    ntok: int,
    num_chunks: int,
    n_heads: int,
):
    """
    Build the vision-encoder attention mask that accounts for padding tokens.

    For every entry of ``ar`` a square mask over ``num_chunks * x.shape[2]``
    positions is produced. A position counts as "real" when it lies within the
    first ``entry[0] * entry[1]`` chunks and the first ``ntok`` tokens of its
    chunk; every other position is padding. A mask cell is non-zero (scaled by
    ``get_negative_inf_value(x.dtype)``) only where both its row and its column
    are padding positions, and zero everywhere else.

    Returns a tensor on ``x.device`` of shape
    [len(ar), n_heads, num_chunks * x.shape[2], num_chunks * x.shape[2]].
    """
    tokens_per_chunk = x.shape[2]

    def _single_mask(aspect):
        # 1 marks padding, 0 marks a real token.
        is_padding = torch.ones((num_chunks, tokens_per_chunk, 1), dtype=x.dtype)
        n_real_chunks = aspect[0] * aspect[1]
        is_padding[:n_real_chunks, :ntok] = 0
        flat = is_padding.view(num_chunks * tokens_per_chunk, -1)
        # Outer product: 1 only where row and column are both padding.
        pairwise = flat @ flat.T * get_negative_inf_value(x.dtype)
        return pairwise.unsqueeze(0)

    per_sample = [_single_mask(aspect) for aspect in ar]
    # The singleton head axis is broadcast to n_heads without copying.
    return torch.stack(per_sample).to(x.device).expand(-1, n_heads, -1, -1)
def expand_num_tokens_to_mult8(x):
    """
    Zero-pad the token axis (dim -2) of ``x`` up to the next multiple of 8.

    Args:
        x: 4-D tensor; the padding is appended along dim -2 and matches
            ``x`` in dtype and device.

    Returns:
        A ``(tensor, num_pad_tokens)`` pair. When the token count is already a
        multiple of 8 the input is returned as-is with ``num_pad_tokens == 0``;
        otherwise the padded tensor and the number of appended tokens (1..7).
    """
    # The outer modulo maps an already-aligned length to 0 instead of 8, so an
    # aligned input is not padded with a full extra block of 8 tokens.
    num_pad_tokens = (8 - x.shape[-2] % 8) % 8
    if num_pad_tokens == 0:
        return x, 0

    padding = torch.zeros(
        (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
        dtype=x.dtype,
        device=x.device,
    )
    return torch.cat([x, padding], dim=-2), num_pad_tokens
def contract_num_tokens_from_mult8(x, num_pad_tokens):
    """
    Drop the last ``num_pad_tokens`` entries along dim 2 of ``x``.

    With ``num_pad_tokens == 0`` the input is returned unchanged; this case is
    handled explicitly because a ``[:-0]`` slice would be empty.
    """
    return x if num_pad_tokens == 0 else x[:, :, :-num_pad_tokens]