chore: add mypy coverage to meta_llama3_multimodal

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-10 11:50:54 +02:00
parent d880c2df0e
commit e3de28d872
2 changed files with 28 additions and 16 deletions

llama_stack/models/llama/llama3/multimodal/model.py

@@ -10,17 +10,17 @@ from collections.abc import Callable
 from functools import partial
 from typing import Any

-import fairscale.nn.model_parallel.initialize as fs_init
-import torch
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.layers import (
+import fairscale.nn.model_parallel.initialize as fs_init  # type: ignore
+import torch  # type: ignore
+import torch.nn.functional as F  # type: ignore
+from fairscale.nn.model_parallel.layers import (  # type: ignore
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
 )
 from PIL import Image as PIL_Image
 from torch import Tensor, nn
-from torch.distributed import _functional_collectives as funcol
+from torch.distributed import _functional_collectives as funcol  # type: ignore

 from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
 from .encoder_utils import (
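
The `# type: ignore` comments above suppress mypy errors on imports of packages that ship without usable type information in this environment. A minimal sketch of a slightly narrower variant of the same pattern, scoping the ignore to the import error code so the rest of the line is still checked (this scoped form is an option, not what the commit itself uses; it assumes fairscale is installed, as the model file already requires):

```python
# Scoped suppression: only the "module has no type stubs" error is silenced,
# other mypy errors on this line would still be reported.
import fairscale.nn.model_parallel.initialize as fs_init  # type: ignore[import-untyped]

# At runtime the ignore comment has no effect; this is an ordinary module object.
print(type(fs_init))
```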
@@ -324,7 +324,7 @@ class VisionEncoder(nn.Module):
     def __init__(
         self,
         max_num_tiles: int,
-        ckpt_path: str = None,
+        ckpt_path: str | None = None,
         image_size: int = 224,
         patch_size: int = 14,
         width: int = 1280,
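
This hunk (and several below) fixes implicit Optional defaults: modern mypy, with `no_implicit_optional` on by default, rejects `ckpt_path: str = None` because `None` is not a `str`. A short sketch of the before/after, using a hypothetical `load_weights` helper rather than the actual constructor:

```python
# Rejected by current mypy: "Incompatible default for argument" because the
# annotation says str but the default is None (implicit Optional).
def load_weights_bad(ckpt_path: str = None):  # noqa: illustration only
    ...


# The pattern applied in this commit: spell out the None-ness in the annotation.
def load_weights(ckpt_path: str | None = None) -> None:
    if ckpt_path is not None:
        print(f"would load checkpoint from {ckpt_path}")


load_weights()            # fine: default None
load_weights("model.pt")  # fine: concrete path
```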
@@ -395,11 +395,11 @@ class VisionEncoder(nn.Module):
         prefix: str,
         local_metadata: dict[str, Any],
         strict: bool = True,
-        missing_keys: list[str] = None,
-        unexpected_keys: list[str] = None,
-        error_msgs: list[str] = None,
+        missing_keys: list[str] | None = None,
+        unexpected_keys: list[str] | None = None,
+        error_msgs: list[str] | None = None,
         return_state_dict: bool = False,
-    ) -> None:
+    ) -> dict[str, Any] | None:
         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
         if orig_pos_embed is not None:
             new_pos_embed = resize_local_position_embedding(orig_pos_embed, self.grid_size)
@@ -429,6 +429,7 @@ class VisionEncoder(nn.Module):
             state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
         if return_state_dict:
             return state_dict
+        return None

     def apply_positional_embedding(self, x, ar):
         # apply regular position embedding
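
The two hunks above go together: the hook sometimes returns the modified state dict and otherwise returns nothing, so the declared return type becomes the union of both outcomes and the trailing `return None` makes the fall-through path explicit for mypy. A self-contained sketch of the same shape (the function name is a stand-in, not the hook itself):

```python
from typing import Any


def maybe_return_state_dict(
    state_dict: dict[str, Any],
    return_state_dict: bool = False,
) -> dict[str, Any] | None:
    if return_state_dict:
        return state_dict
    return None  # explicit, so every code path matches the annotation


print(maybe_return_state_dict({"w": 1}, return_state_dict=True))  # {'w': 1}
print(maybe_return_state_dict({"w": 1}))                          # None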
@@ -571,6 +572,10 @@ class Attention(nn.Module):
         )
         self.n_heads = args.n_heads

+        # Initialize cache buffers as None, they will be properly typed in setup_cache
+        self.key_cache: torch.Tensor | None = None
+        self.value_cache: torch.Tensor | None = None
+
     def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
         cache_shape = (
             max_batch_size,
@@ -602,6 +607,9 @@ class Attention(nn.Module):
         freqs_cis: torch.Tensor,
         position_ids: torch.LongTensor,
     ):
+        assert self.key_cache is not None and self.value_cache is not None, (
+            "Cache must be set up before calling forward"
+        )
         self.key_cache = self.key_cache.to(x.device)
         self.value_cache = self.value_cache.to(x.device)
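
These two hunks use the standard late-initialized-attribute pattern: declare the caches as `torch.Tensor | None` in `__init__`, populate them in `setup_cache`, and narrow the type with an `assert ... is not None` at the top of `forward` so mypy accepts the subsequent `.to(...)` calls. A reduced sketch under the assumption that torch is installed (class and sizes are illustrative only):

```python
import torch


class TinyCache:
    def __init__(self) -> None:
        # Declared Optional because the real buffer is created later.
        self.key_cache: torch.Tensor | None = None

    def setup_cache(self, max_batch_size: int) -> None:
        self.key_cache = torch.zeros(max_batch_size, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.key_cache is not None, "Cache must be set up before calling forward"
        # After the assert, mypy narrows key_cache from Tensor | None to Tensor.
        return x + self.key_cache[: x.shape[0]]


cache = TinyCache()
cache.setup_cache(max_batch_size=4)
print(cache.forward(torch.ones(2, 8)).shape)  # torch.Size([2, 8])
```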
@@ -791,7 +799,7 @@ class TilePositionEmbedding(nn.Module):
         embed_new = embed_new.permute(2, 3, 0, 1)
         return embed_new

-    def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None):
+    def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int | None = None):
         embed = self.embedding
         if num_tiles is None:
             num_tiles = self.num_tiles
@@ -1027,12 +1035,13 @@ class DummySelfAttentionTransformerBlock:
 class CrossAttentionTransformerVision(torch.nn.Module):
     def __init__(self, args: ModelArgs) -> None:
         super().__init__()
-        return_intermediate = "3,7,15,23,30"
+        return_intermediate: str | list[int] | None = "3,7,15,23,30"
         self.vision_input_dim = 1280
         self.image_res = args.vision_chunk_size
         self.max_num_chunks = args.vision_max_num_chunks
         if return_intermediate is not None:
-            return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
+            if isinstance(return_intermediate, str):
+                return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
             self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
         self.patch_size = 14
         self.vision_encoder = VisionEncoder(
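
Once `return_intermediate` is annotated as `str | list[int] | None`, mypy will not allow `.split(",")` on the union, so the added `isinstance` check narrows the variable to `str` before parsing and to `list[int]` afterwards. A standalone sketch of the same narrowing, factored into a hypothetical helper:

```python
def parse_return_intermediate(value: str | list[int] | None) -> list[int]:
    if value is None:
        return []
    if isinstance(value, str):
        # Narrowed to str here, so .split() type-checks.
        value = [int(layer) for layer in value.split(",")]
    # Narrowed to list[int] on every remaining path.
    return value


print(parse_return_intermediate("3,7,15,23,30"))  # [3, 7, 15, 23, 30]
print(parse_return_intermediate([1, 2]))          # [1, 2]
print(parse_return_intermediate(None))            # []
```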
@@ -1142,6 +1151,9 @@ class CrossAttentionTransformerText(torch.nn.Module):
         self.cache_is_setup = False
         self.max_seq_len = args.max_seq_len

+        # Initialize mask_cache as None, it will be properly typed in setup_cache
+        self.mask_cache: torch.Tensor | None = None
+
     def _init_fusion_schedule(
         self,
         num_layers: int,
@@ -1175,6 +1187,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
         text_only_inference: bool = False,
     ):
         assert self.cache_is_setup, "Please set up cache before calling forward"
+        assert self.mask_cache is not None, "Mask cache must be set up before calling forward"
         self.mask_cache = self.mask_cache.to(h.device)
         self.freqs_cis = self.freqs_cis.to(h.device)
         mask = self.mask_cache.index_select(2, position_ids)
@@ -1372,11 +1385,11 @@ class CrossAttentionTransformer(torch.nn.Module):
 def _stack_images(
-    images: list[list[PIL_Image.Image]],
+    images: list[list[torch.Tensor]],  # Changed type annotation from PIL_Image.Image to torch.Tensor
     max_num_chunks: int,
     image_res: int,
     max_num_images: int,
-) -> tuple[torch.Tensor, list[int]]:
+) -> tuple[torch.Tensor, list[list[int]]]:
     """
     Takes a list of list of images and stacks them into a tensor.
     This function is needed since images can be of completely

pyproject.toml

@@ -248,7 +248,6 @@ exclude = [
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
     "^llama_stack/models/llama/llama3/generation\\.py$",
-    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
    "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",