From e3de28d872eed978d96fcc5520726bfa0be4fc5e Mon Sep 17 00:00:00 2001
From: Mustafa Elbehery
Date: Thu, 10 Jul 2025 11:50:54 +0200
Subject: [PATCH] chore: add mypy coverage to meta_llama3_multimodal

Signed-off-by: Mustafa Elbehery
---
 .../models/llama/llama3/multimodal/model.py | 43 ++++++++++++-------
 pyproject.toml                              |  1 -
 2 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py
index 5f1c3605c..9b9ef9c8d 100644
--- a/llama_stack/models/llama/llama3/multimodal/model.py
+++ b/llama_stack/models/llama/llama3/multimodal/model.py
@@ -10,17 +10,17 @@ from collections.abc import Callable
 from functools import partial
 from typing import Any
 
-import fairscale.nn.model_parallel.initialize as fs_init
-import torch
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.layers import (
+import fairscale.nn.model_parallel.initialize as fs_init  # type: ignore
+import torch  # type: ignore
+import torch.nn.functional as F  # type: ignore
+from fairscale.nn.model_parallel.layers import (  # type: ignore
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
 )
 from PIL import Image as PIL_Image
 from torch import Tensor, nn
-from torch.distributed import _functional_collectives as funcol
+from torch.distributed import _functional_collectives as funcol  # type: ignore
 
 from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
 from .encoder_utils import (
@@ -324,7 +324,7 @@ class VisionEncoder(nn.Module):
     def __init__(
         self,
         max_num_tiles: int,
-        ckpt_path: str = None,
+        ckpt_path: str | None = None,
         image_size: int = 224,
         patch_size: int = 14,
         width: int = 1280,
@@ -395,11 +395,11 @@ class VisionEncoder(nn.Module):
         prefix: str,
         local_metadata: dict[str, Any],
         strict: bool = True,
-        missing_keys: list[str] = None,
-        unexpected_keys: list[str] = None,
-        error_msgs: list[str] = None,
+        missing_keys: list[str] | None = None,
+        unexpected_keys: list[str] | None = None,
+        error_msgs: list[str] | None = None,
         return_state_dict: bool = False,
-    ) -> None:
+    ) -> dict[str, Any] | None:
         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
         if orig_pos_embed is not None:
             new_pos_embed = resize_local_position_embedding(orig_pos_embed, self.grid_size)
@@ -429,6 +429,7 @@ class VisionEncoder(nn.Module):
                 state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
         if return_state_dict:
             return state_dict
+        return None
 
     def apply_positional_embedding(self, x, ar):
         # apply regular position embedding
@@ -571,6 +572,10 @@ class Attention(nn.Module):
         )
         self.n_heads = args.n_heads
 
+        # Initialize cache buffers as None, they will be properly typed in setup_cache
+        self.key_cache: torch.Tensor | None = None
+        self.value_cache: torch.Tensor | None = None
+
     def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
         cache_shape = (
             max_batch_size,
@@ -602,6 +607,9 @@ class Attention(nn.Module):
         freqs_cis: torch.Tensor,
         position_ids: torch.LongTensor,
     ):
+        assert self.key_cache is not None and self.value_cache is not None, (
+            "Cache must be set up before calling forward"
+        )
         self.key_cache = self.key_cache.to(x.device)
         self.value_cache = self.value_cache.to(x.device)
 
@@ -791,7 +799,7 @@ class TilePositionEmbedding(nn.Module):
         embed_new = embed_new.permute(2, 3, 0, 1)
         return embed_new
 
-    def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None):
+    def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int | None = None):
         embed = self.embedding
         if num_tiles is None:
             num_tiles = self.num_tiles
@@ -1027,12 +1035,13 @@ class DummySelfAttentionTransformerBlock:
 class CrossAttentionTransformerVision(torch.nn.Module):
     def __init__(self, args: ModelArgs) -> None:
         super().__init__()
-        return_intermediate = "3,7,15,23,30"
+        return_intermediate: str | list[int] | None = "3,7,15,23,30"
         self.vision_input_dim = 1280
         self.image_res = args.vision_chunk_size
         self.max_num_chunks = args.vision_max_num_chunks
         if return_intermediate is not None:
-            return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
+            if isinstance(return_intermediate, str):
+                return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
             self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
         self.patch_size = 14
         self.vision_encoder = VisionEncoder(
@@ -1142,6 +1151,9 @@ class CrossAttentionTransformerText(torch.nn.Module):
         self.cache_is_setup = False
         self.max_seq_len = args.max_seq_len
 
+        # Initialize mask_cache as None, it will be properly typed in setup_cache
+        self.mask_cache: torch.Tensor | None = None
+
     def _init_fusion_schedule(
         self,
         num_layers: int,
@@ -1175,6 +1187,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
         text_only_inference: bool = False,
     ):
         assert self.cache_is_setup, "Please set up cache before calling forward"
+        assert self.mask_cache is not None, "Mask cache must be set up before calling forward"
         self.mask_cache = self.mask_cache.to(h.device)
         self.freqs_cis = self.freqs_cis.to(h.device)
         mask = self.mask_cache.index_select(2, position_ids)
@@ -1372,11 +1385,11 @@ class CrossAttentionTransformer(torch.nn.Module):
 
 
 def _stack_images(
-    images: list[list[PIL_Image.Image]],
+    images: list[list[torch.Tensor]],  # Changed type annotation from PIL_Image.Image to torch.Tensor
     max_num_chunks: int,
     image_res: int,
     max_num_images: int,
-) -> tuple[torch.Tensor, list[int]]:
+) -> tuple[torch.Tensor, list[list[int]]]:
     """
     Takes a list of list of images and stacks them into a tensor.
     This function is needed since images can be of completely
diff --git a/pyproject.toml b/pyproject.toml
index d84a823a3..5bad8ae2f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -248,7 +248,6 @@ exclude = [
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
     "^llama_stack/models/llama/llama3/generation\\.py$",
-    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
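
Note (reviewer sketch, not part of the patch): the model.py changes apply one pattern throughout: declare lazily initialized cache buffers as torch.Tensor | None in __init__, then narrow the type with an assert before first use so mypy accepts the attribute access. Below is a minimal, self-contained illustration of that pattern; the names LazyCache and build are invented for the example and do not appear in the patch.

    import torch

    class LazyCache(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Declared as Optional: the real buffer is only allocated later,
            # mirroring key_cache/value_cache/mask_cache in the patch.
            self.cache: torch.Tensor | None = None

        def build(self, size: int, dtype: torch.dtype = torch.float32) -> None:
            # Late initialization, analogous to setup_cache in the patched classes.
            self.cache = torch.zeros(size, dtype=dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The assert narrows torch.Tensor | None to torch.Tensor for mypy,
            # which is what the added asserts in Attention.forward and
            # CrossAttentionTransformerText.forward accomplish.
            assert self.cache is not None, "call build() before forward()"
            return x + self.cache

    m = LazyCache()
    m.build(4)
    print(m(torch.ones(4)))  # tensor([1., 1., 1., 1.])

Without the Optional annotation, mypy would reject assigning None in __init__; without the assert (or an equivalent None check), it would reject calls such as self.cache.to(...) on a possibly-None attribute.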