mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 17:11:12 +00:00 
			
		
		
		
	# What does this PR do? The goal of this PR is codebase modernization. Schema reflection code needed a minor adjustment to handle `UnionType`s and `collections.abc.AsyncIterator` (both are preferred in recent Python releases). Note to reviewers: almost all changes here were generated automatically by pyupgrade, and some additional unused imports were cleaned up. The only changes worth noting are under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to handle "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
		
			
				
	
	
		
			412 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			412 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from collections.abc import Callable
 | |
| from typing import Any
 | |
| 
 | |
| import fairscale.nn.model_parallel.initialize as fs_init
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 | |
| from torch import einsum
 | |
| 
 | |
| from ..args import ModelArgs
 | |
| from ..model import Attention
 | |
| 
 | |
| 
 | |
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` subclass whose ``forward`` calls ``F.layer_norm`` directly.

    The input is normalized with this module's ``normalized_shape``, affine
    ``weight``/``bias`` and ``eps``. No explicit dtype conversion is performed.

    NOTE(review): this class has historically been described as existing "to
    handle fp16", but nothing here upcasts or downcasts the input -- confirm
    whether an fp32 round-trip was intended.
    """

    def forward(self, x: torch.Tensor):
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
 | |
| 
 | |
| 
 | |
| class ColumnParallelConv2dPatch(torch.nn.Module):
 | |
|     """Conv2D Patching layer with model parallelism.
 | |
|     Column parallel over unfolded input.
 | |
|     Arguments:
 | |
|         in_channels: Input channels.
 | |
|         out_channels: Output channels.
 | |
|         kernel_size: Size of convolution kernel.
 | |
|         stride (default 1): Stride for convolution.
 | |
|         bias (default False): Use bias in Conv2d.
 | |
|     Input: (bsz, in_channels, height, width)
 | |
|     Output: (bsz, num_tokens, out_channels)
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         in_channels: int,
 | |
|         out_channels: int,
 | |
|         kernel_size: int | tuple[int, int],
 | |
|         stride: int | tuple[int, int],
 | |
|         bias: bool | None = False,
 | |
|     ) -> None:
 | |
|         super().__init__()
 | |
|         if isinstance(kernel_size, int):
 | |
|             kernel_size = (kernel_size, kernel_size)
 | |
|         self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
 | |
|         self._linear = ColumnParallelLinear(
 | |
|             in_channels * kernel_size[0] * kernel_size[1],
 | |
|             out_channels,
 | |
|             bias=bias,
 | |
|         )
 | |
| 
 | |
|     def forward(self, x: torch.Tensor) -> torch.Tensor:
 | |
|         x = self._unfold(x)
 | |
|         x = x.permute(0, 2, 1)
 | |
|         x = self._linear(x)
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class _FeedForward(torch.nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         dim: int,
 | |
|         hidden_dim: int,
 | |
|         dropout: float,
 | |
|         act_layer: Callable = nn.GELU,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         # layers
 | |
|         self.c_fc = ColumnParallelLinear(
 | |
|             dim,
 | |
|             hidden_dim,
 | |
|             bias=True,
 | |
|             gather_output=False,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.c_proj = RowParallelLinear(
 | |
|             hidden_dim,
 | |
|             dim,
 | |
|             bias=True,
 | |
|             input_is_parallel=True,
 | |
|             init_method=lambda x: x,
 | |
|         )
 | |
|         self.non_linearity = act_layer()
 | |
|         self.dropout = dropout
 | |
| 
 | |
|     def forward(self, x):
 | |
|         hidden = self.c_fc(x)
 | |
|         hidden = self.non_linearity(hidden)
 | |
|         hidden = F.dropout(hidden, p=self.dropout, training=self.training)
 | |
|         return self.c_proj(hidden)
 | |
| 
 | |
| 
 | |
| class _TransformerBlock(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         d_model: int,
 | |
|         n_head: int,
 | |
|         mlp_ratio: float = 4.0,
 | |
|         act_layer: Callable = nn.GELU,
 | |
|         gated: bool = False,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         assert d_model % n_head == 0
 | |
|         self.n_heads = n_head
 | |
|         self.head_dim = d_model // self.n_heads
 | |
| 
 | |
|         attn_args = ModelArgs(
 | |
|             dim=d_model,
 | |
|             head_dim=self.head_dim,
 | |
|             n_heads=self.n_heads,
 | |
|             n_kv_heads=self.n_heads,
 | |
|         )
 | |
|         self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
 | |
|         self.ln_1 = LayerNorm(d_model)
 | |
|         self.mlp = _FeedForward(
 | |
|             dim=d_model,
 | |
|             hidden_dim=int(mlp_ratio * d_model),
 | |
|             dropout=0.0,
 | |
|             act_layer=act_layer,
 | |
|         )
 | |
|         self.ln_2 = LayerNorm(d_model)
 | |
|         self.gated = gated
 | |
|         if gated:
 | |
|             self.gate_attn = nn.Parameter(torch.zeros(1))
 | |
|             self.gate_ffn = nn.Parameter(torch.zeros(1))
 | |
| 
 | |
|     def attention(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         freq_cis: torch.Tensor | None = None,
 | |
|     ):
 | |
|         return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         mask: torch.Tensor | None = None,
 | |
|         freq_cis: torch.Tensor | None = None,
 | |
|     ):
 | |
|         _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
 | |
|         _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
 | |
| 
 | |
|         x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
 | |
|         x = x + _gate_ffn * self.mlp(self.ln_2(x))
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class _Transformer(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         dim: int,
 | |
|         layers: int,
 | |
|         heads: int,
 | |
|         mlp_ratio: float = 4.0,
 | |
|         act_layer: Callable = nn.GELU,
 | |
|         gated: bool = False,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.resblocks = nn.ModuleList(
 | |
|             [
 | |
|                 _TransformerBlock(
 | |
|                     d_model=dim,
 | |
|                     n_head=heads,
 | |
|                     mlp_ratio=mlp_ratio,
 | |
|                     act_layer=act_layer,
 | |
|                     gated=gated,
 | |
|                 )
 | |
|                 for _ in range(layers)
 | |
|             ]
 | |
|         )
 | |
| 
 | |
|     def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
 | |
|         out = []
 | |
|         for idx, r in enumerate(self.resblocks):
 | |
|             if return_intermediate is not None and idx in return_intermediate:
 | |
|                 out.append(x)
 | |
|             x = r(x, mask=mask, freq_cis=freq_cis)
 | |
|         if return_intermediate is not None:
 | |
|             return x, torch.stack(out, dim=-1)
 | |
|         return x
 | |
| 
 | |
| 
 | |
class PackingIndex:
    """Column indices into per-token packing metadata, plus sentinel ``IDX`` values."""

    # Coordinates of the token in the original sample: Z (time), Y (height), X (width).
    Z, Y, X = 0, 1, 2
    # Extent of the original sample: frames, height, width.
    TIME, HEIGHT, WIDTH = 3, 4, 5
    # Full flat index of the token (x + y * w + z * w * h).
    # Check this column to tell the token type apart (see the ID_* sentinels below).
    IDX = 6
    # Batch element the token belongs to; padding tokens use BATCH_SIZE.
    BATCH_IDX = 7

    # Total number of metadata columns -- keep in sync with the fields above.
    NUM_METADATA = 8

    # Sentinel IDX values: padding tokens use -1, class tokens use -2.
    ID_CLS_TOKEN = -2
    ID_PAD_TOKEN = -1
 | |
| 
 | |
| 
 | |
| class VisionEncoder(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         image_size: tuple[int, int],
 | |
|         patch_size: tuple[int, int],
 | |
|         dim: int,
 | |
|         layers: int,
 | |
|         heads: int,
 | |
|         mlp_ratio: float,
 | |
|         in_channels: int = 3,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.image_size = image_size
 | |
|         self.patch_size = patch_size
 | |
|         self.grid_size = (
 | |
|             self.image_size[0] // self.patch_size[0],
 | |
|             self.image_size[1] // self.patch_size[1],
 | |
|         )
 | |
|         self.conv1 = ColumnParallelConv2dPatch(
 | |
|             in_channels=in_channels,
 | |
|             out_channels=dim,
 | |
|             kernel_size=patch_size,
 | |
|             stride=patch_size,
 | |
|             bias=False,
 | |
|         )
 | |
|         scale = dim**-0.5
 | |
|         self.class_embedding = nn.Parameter(scale * torch.randn(dim))
 | |
| 
 | |
|         self.positional_embedding_vlm = nn.Parameter(
 | |
|             scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
 | |
|         )
 | |
| 
 | |
|         self.ln_pre = LayerNorm(dim)
 | |
|         self.ln_post = LayerNorm(dim)
 | |
|         self.transformer = _Transformer(
 | |
|             dim,
 | |
|             layers,
 | |
|             heads,
 | |
|             mlp_ratio,
 | |
|             act_layer=nn.GELU,
 | |
|         )
 | |
| 
 | |
|         # NOTE: hack for the fixed res
 | |
|         image_h, image_w = self.image_size
 | |
|         patch_h, patch_w = self.patch_size
 | |
|         idx_h, idx_w = image_h // patch_h, image_w // patch_w
 | |
|         img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
 | |
|         img_idx = img_idx.reshape(idx_h * idx_w, 1)
 | |
|         img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
 | |
|         img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
 | |
| 
 | |
|         packed_img_idx = torch.empty(
 | |
|             img_idx.shape[0],
 | |
|             img_idx.shape[1],
 | |
|             PackingIndex.NUM_METADATA - 1,
 | |
|             dtype=torch.int32,
 | |
|         )
 | |
|         packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
 | |
|         packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
 | |
|         packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
 | |
|         packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
 | |
|         packed_img_idx[:, :, PackingIndex.IDX] = img_idx
 | |
|         packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
 | |
|         self.packed_img_idx = packed_img_idx  # for positional embedding load hook
 | |
| 
 | |
|         # compute rope freqs
 | |
|         rope_freq = self.get_rope_freqs(dim // heads // 2)
 | |
|         freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
 | |
|         freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
 | |
|         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
 | |
|         # disable RoPE for padding and cls tokens
 | |
|         freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
 | |
|         # compute complex freqs
 | |
|         self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
 | |
|         # xlf automatically broadcasts
 | |
|         self.freq_cis = self.freq_cis.squeeze(0)
 | |
|         self.n_heads = heads // fs_init.get_model_parallel_world_size()
 | |
| 
 | |
|         self._register_load_state_dict_pre_hook(self.load_hook)
 | |
| 
 | |
|     def get_rope_freqs(self, dim, theta=10000):
 | |
|         freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
 | |
|         return freqs
 | |
| 
 | |
|     @torch.amp.autocast("cuda", enabled=False)
 | |
|     def compute_rope_freqs(self, freqs, t):
 | |
|         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
 | |
|         freqs = freqs.repeat_interleave(2, dim=-1)
 | |
|         return freqs
 | |
| 
 | |
|     def load_hook(
 | |
|         self,
 | |
|         state_dict: dict[str, Any],
 | |
|         prefix: str,
 | |
|         local_metadata: dict[str, Any],
 | |
|         strict: bool = True,
 | |
|         missing_keys: list[str] = None,
 | |
|         unexpected_keys: list[str] = None,
 | |
|         error_msgs: list[str] = None,
 | |
|         return_state_dict: bool = False,
 | |
|     ) -> None:
 | |
|         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
 | |
|         if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
 | |
|             raise ValueError(
 | |
|                 f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
 | |
|             )
 | |
| 
 | |
|         batch_size, token_per_image, _ = self.packed_img_idx.shape
 | |
|         # Input points for idx are [x, y, w, h]
 | |
|         idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
 | |
|         total_windows, window_size, _ = idx.shape
 | |
| 
 | |
|         # Grid values are [-1, 1] and coords are w, h
 | |
|         grid = (
 | |
|             (idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
 | |
|         )[None, ...]
 | |
| 
 | |
|         # In this mode, cls token has no position embedding
 | |
|         if orig_pos_embed is not None:
 | |
|             posemb = (
 | |
|                 orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
 | |
|             )
 | |
|             posemb = posemb.to(device=grid.device, dtype=grid.dtype)
 | |
|             sample = F.grid_sample(
 | |
|                 posemb, grid, padding_mode="zeros"
 | |
|             )  # padding tokens / class token will get zero for posemb
 | |
|             sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
 | |
|             sample = torch.where(
 | |
|                 idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
 | |
|                 orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
 | |
|                 sample,
 | |
|             )
 | |
| 
 | |
|             new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
 | |
| 
 | |
|             state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
 | |
| 
 | |
|         if return_state_dict:
 | |
|             return state_dict
 | |
| 
 | |
|     def apply_class_embedding(self, x):
 | |
|         x = torch.cat(
 | |
|             [
 | |
|                 x,
 | |
|                 self.class_embedding.to(x.dtype)
 | |
|                 + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
 | |
|             ],
 | |
|             dim=1,
 | |
|         )  # shape = [*, grid ** 2 + 1, width]
 | |
|         return x
 | |
| 
 | |
|     def forward(self, images: torch.Tensor) -> torch.Tensor:
 | |
|         # NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
 | |
|         if images.ndim == 5:
 | |
|             num_concurrent_media = 1
 | |
|             bsz, num_chunks, nch, h, w = images.shape
 | |
|         else:
 | |
|             bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
 | |
| 
 | |
|         images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
 | |
|         # patch embedding
 | |
|         x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
 | |
|         x = self.conv1(x)  # shape = [*, width, grid ** 2]
 | |
|         _, ntok, dim = x.shape
 | |
|         x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
 | |
| 
 | |
|         # apply cls token
 | |
|         x = self.apply_class_embedding(x)
 | |
|         ntok += 1
 | |
| 
 | |
|         # apply position embeddings
 | |
|         if self.positional_embedding_vlm is not None:
 | |
|             x = x + self.positional_embedding_vlm.to(x.dtype)
 | |
| 
 | |
|         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
 | |
| 
 | |
|         x = self.ln_pre(x)
 | |
|         x = x.view(bsz * num_concurrent_media, -1, dim)
 | |
|         freq_cis = self.freq_cis.to(images.device)
 | |
| 
 | |
|         tf_output = self.transformer(
 | |
|             x,
 | |
|             freq_cis=freq_cis,
 | |
|         )
 | |
| 
 | |
|         int_x = None
 | |
|         if isinstance(tf_output, tuple):
 | |
|             x, int_x = tf_output
 | |
|         else:
 | |
|             x = tf_output
 | |
|         x = self.ln_post(x)
 | |
| 
 | |
|         # remove cls token output
 | |
|         x = x[:, :-1, :]
 | |
| 
 | |
|         # add and output x + int_x features
 | |
|         if int_x is not None:
 | |
|             int_x = int_x[:, :-1, :, :]
 | |
|             int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
 | |
|             x = torch.cat([x, int_x], dim=-1)
 | |
| 
 | |
|         return x
 |