chore: add mypy coverage to meta_llama3_multimodal
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
parent d880c2df0e
commit e3de28d872

2 changed files with 28 additions and 16 deletions
llama_stack/models/llama/llama3/multimodal/model.py

@@ -10,17 +10,17 @@ from collections.abc import Callable
 from functools import partial
 from typing import Any
 
-import fairscale.nn.model_parallel.initialize as fs_init
-import torch
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.layers import (
+import fairscale.nn.model_parallel.initialize as fs_init  # type: ignore
+import torch  # type: ignore
+import torch.nn.functional as F  # type: ignore
+from fairscale.nn.model_parallel.layers import (  # type: ignore
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
 )
 from PIL import Image as PIL_Image
 from torch import Tensor, nn
-from torch.distributed import _functional_collectives as funcol
+from torch.distributed import _functional_collectives as funcol  # type: ignore
 
 from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
 from .encoder_utils import (
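The `# type: ignore` markers added above silence mypy for imports it cannot fully analyze (untyped or partially typed third-party packages). A minimal standalone sketch of the pattern — the error text is approximate and depends on the mypy version and installed stubs:

    # Without the comment, mypy typically reports something like:
    #   error: Skipping analyzing "fairscale": module is installed, but
    #   missing library stubs or py.typed marker
    import fairscale  # type: ignore

    # A narrower variant pins the specific error code, so unrelated
    # problems on the same line still get reported:
    import fairscale  # type: ignore[import-untyped]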
@@ -324,7 +324,7 @@ class VisionEncoder(nn.Module):
     def __init__(
         self,
         max_num_tiles: int,
-        ckpt_path: str = None,
+        ckpt_path: str | None = None,
         image_size: int = 224,
         patch_size: int = 14,
         width: int = 1280,
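`ckpt_path: str = None` relied on implicit `Optional`, which mypy rejects by default (the `no_implicit_optional` default changed in mypy 0.990). A reduced before/after sketch with a hypothetical function name:

    # Before: mypy reports (approximately) 'Incompatible default for
    # argument "ckpt_path" (default has type "None", argument has type "str")'
    def load_checkpoint(ckpt_path: str = None) -> None: ...

    # After: the annotation admits None explicitly
    def load_checkpoint(ckpt_path: str | None = None) -> None: ...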
@@ -395,11 +395,11 @@ class VisionEncoder(nn.Module):
         prefix: str,
         local_metadata: dict[str, Any],
         strict: bool = True,
-        missing_keys: list[str] = None,
-        unexpected_keys: list[str] = None,
-        error_msgs: list[str] = None,
+        missing_keys: list[str] | None = None,
+        unexpected_keys: list[str] | None = None,
+        error_msgs: list[str] | None = None,
         return_state_dict: bool = False,
-    ) -> None:
+    ) -> dict[str, Any] | None:
         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
         if orig_pos_embed is not None:
             new_pos_embed = resize_local_position_embedding(orig_pos_embed, self.grid_size)
@@ -429,6 +429,7 @@ class VisionEncoder(nn.Module):
             state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
         if return_state_dict:
             return state_dict
+        return None
 
     def apply_positional_embedding(self, x, ar):
         # apply regular position embedding
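Since the signature now promises `dict[str, Any] | None`, the new `return None` makes the fall-through path explicit. A reduced sketch of the hook's shape, with hypothetical names:

    from typing import Any

    def load_hook(state_dict: dict[str, Any], return_state_dict: bool) -> dict[str, Any] | None:
        if return_state_dict:
            return state_dict
        return None  # explicit; a bare fall-through would return None implicitly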
@@ -571,6 +572,10 @@ class Attention(nn.Module):
         )
         self.n_heads = args.n_heads
 
+        # Initialize cache buffers as None, they will be properly typed in setup_cache
+        self.key_cache: torch.Tensor | None = None
+        self.value_cache: torch.Tensor | None = None
+
     def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
         cache_shape = (
             max_batch_size,
@@ -602,6 +607,9 @@ class Attention(nn.Module):
         freqs_cis: torch.Tensor,
         position_ids: torch.LongTensor,
     ):
+        assert self.key_cache is not None and self.value_cache is not None, (
+            "Cache must be set up before calling forward"
+        )
         self.key_cache = self.key_cache.to(x.device)
         self.value_cache = self.value_cache.to(x.device)
 
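Annotating the caches as `torch.Tensor | None` means mypy treats every later access as possibly `None`; the `assert` at the top of `forward` narrows the type for the rest of the method. A self-contained sketch of the pattern, with illustrative names and shapes:

    import torch

    class KVCache:
        def __init__(self) -> None:
            # Unset until setup_cache() runs, hence Optional.
            self.key_cache: torch.Tensor | None = None

        def setup_cache(self, max_batch_size: int) -> None:
            self.key_cache = torch.zeros(max_batch_size, 8, 128)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Narrows key_cache from "Tensor | None" to "Tensor" below.
            assert self.key_cache is not None, "Cache must be set up before calling forward"
            return self.key_cache.to(x.device)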
@@ -791,7 +799,7 @@ class TilePositionEmbedding(nn.Module):
         embed_new = embed_new.permute(2, 3, 0, 1)
         return embed_new
 
-    def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None):
+    def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int | None = None):
         embed = self.embedding
         if num_tiles is None:
             num_tiles = self.num_tiles
@@ -1027,12 +1035,13 @@ class DummySelfAttentionTransformerBlock:
 class CrossAttentionTransformerVision(torch.nn.Module):
     def __init__(self, args: ModelArgs) -> None:
         super().__init__()
-        return_intermediate = "3,7,15,23,30"
+        return_intermediate: str | list[int] | None = "3,7,15,23,30"
         self.vision_input_dim = 1280
         self.image_res = args.vision_chunk_size
         self.max_num_chunks = args.vision_max_num_chunks
         if return_intermediate is not None:
-            return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
+            if isinstance(return_intermediate, str):
+                return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
             self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
         self.patch_size = 14
         self.vision_encoder = VisionEncoder(
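With `return_intermediate` annotated as `str | list[int] | None`, calling `.split(",")` directly would be an error on the non-`str` branches; the added `isinstance` check narrows the union so mypy accepts the call. The same idea as a standalone, hypothetical helper:

    def parse_layers(spec: str | list[int] | None) -> list[int] | None:
        if spec is None:
            return None
        if isinstance(spec, str):
            # Here mypy knows spec is str, so .split is valid.
            spec = [int(layer) for layer in spec.split(",")]
        return spec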
@@ -1142,6 +1151,9 @@ class CrossAttentionTransformerText(torch.nn.Module):
         self.cache_is_setup = False
         self.max_seq_len = args.max_seq_len
 
+        # Initialize mask_cache as None, it will be properly typed in setup_cache
+        self.mask_cache: torch.Tensor | None = None
+
     def _init_fusion_schedule(
         self,
         num_layers: int,
@@ -1175,6 +1187,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
         text_only_inference: bool = False,
     ):
         assert self.cache_is_setup, "Please set up cache before calling forward"
+        assert self.mask_cache is not None, "Mask cache must be set up before calling forward"
         self.mask_cache = self.mask_cache.to(h.device)
         self.freqs_cis = self.freqs_cis.to(h.device)
         mask = self.mask_cache.index_select(2, position_ids)
@@ -1372,11 +1385,11 @@ class CrossAttentionTransformer(torch.nn.Module):
 
 
 def _stack_images(
-    images: list[list[PIL_Image.Image]],
+    images: list[list[torch.Tensor]],  # Changed type annotation from PIL_Image.Image to torch.Tensor
     max_num_chunks: int,
     image_res: int,
     max_num_images: int,
-) -> tuple[torch.Tensor, list[int]]:
+) -> tuple[torch.Tensor, list[list[int]]]:
     """
     Takes a list of list of images and stacks them into a tensor.
     This function is needed since images can be of completely
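The `_stack_images` change corrects annotations to match runtime behavior: the images arriving here appear to already be tensors, and the second return value is nested per sample, hence `list[list[int]]` rather than `list[int]`. A reduced, hypothetical sketch of why the nesting matters:

    import torch

    def stack(images: list[list[torch.Tensor]]) -> tuple[torch.Tensor, list[list[int]]]:
        # One inner list per sample; a flat list[int] would lose the nesting.
        num_chunks = [[im.shape[0] for im in sample] for sample in images]
        # Assumes equal-sized tensors for the sake of the sketch.
        flat = [im for sample in images for im in sample]
        return torch.stack(flat), num_chunks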
pyproject.toml

@@ -248,7 +248,6 @@ exclude = [
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
     "^llama_stack/models/llama/llama3/generation\\.py$",
-    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
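Dropping the file from mypy's `exclude` list is what actually brings it under checking; the code changes above exist so that this passes. One hypothetical way to verify from Python, via mypy's programmatic API (equivalent to running mypy on the path from the command line):

    from mypy import api

    stdout, stderr, exit_code = api.run(["llama_stack/models/llama/llama3/multimodal/model.py"])
    print(stdout)
    assert exit_code == 0, "mypy still reports errors for this file"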