From 257eaeb945e92ea14140f99f19d9c1bd745ea6f3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 27 Oct 2025 22:35:24 -0700 Subject: [PATCH] fix(mypy): resolve model implementation typing issues Fixes type errors across 4 model implementation files (Phase 2d): - image_transform.py: Fix return type annotations and variable shadowing - Changed find_supported_resolutions return type from Tensor to list[tuple[int, int]] - Changed resize_without_distortion parameter/return from Tensor to Image.Image - Use separate variable names for list vs tensor (possible_resolutions_list/tensor) - checkpoint.py: Replace deprecated torch.BFloat16Tensor usage - Use torch.set_default_dtype(torch.bfloat16) instead of deprecated tensor types - Rename ckpt_paths to ckpt_paths_array to avoid variable shadowing - hadamard_utils.py: Add type assertion for nn.Linear - Assert isinstance check to narrow type from Module to Linear before accessing in_features - encoder_utils.py: Fix variable shadowing - Use masks_list for list accumulation, masks for final Tensor result Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/llama_stack/models/llama/checkpoint.py | 12 ++++++------ src/llama_stack/models/llama/hadamard_utils.py | 2 ++ .../models/llama/llama3/multimodal/encoder_utils.py | 6 +++--- .../llama/llama3/multimodal/image_transform.py | 12 ++++++------ 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/llama_stack/models/llama/checkpoint.py b/src/llama_stack/models/llama/checkpoint.py index c9e0030e3..b00e2ed18 100644 --- a/src/llama_stack/models/llama/checkpoint.py +++ b/src/llama_stack/models/llama/checkpoint.py @@ -38,18 +38,18 @@ def maybe_reshard_state_dict( mmap: bool = True, ) -> dict[str, torch.Tensor]: if str(map_location) == "cpu": - torch.set_default_tensor_type(torch.BFloat16Tensor) + torch.set_default_dtype(torch.bfloat16) else: - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + torch.set_default_dtype(torch.bfloat16) - ckpt_paths = np.array(sorted(ckpt_paths)) + ckpt_paths_array = np.array(sorted(ckpt_paths)) new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank() - old_mp_size = len(ckpt_paths) + old_mp_size = len(ckpt_paths_array) old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank) - print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore - paths = ckpt_paths[old_mp_ranks] # type: ignore + print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}") # type: ignore + paths = ckpt_paths_array[old_mp_ranks] # type: ignore state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths] if new_mp_size == old_mp_size: diff --git a/src/llama_stack/models/llama/hadamard_utils.py b/src/llama_stack/models/llama/hadamard_utils.py index 87f3829d0..02b569aaf 100644 --- a/src/llama_stack/models/llama/hadamard_utils.py +++ b/src/llama_stack/models/llama/hadamard_utils.py @@ -79,6 +79,8 @@ def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = " for module_name, module in model.named_children(): child_full_name = prefix + "." + module_name if re.search(pattern_last_linear_ffn, child_full_name): + # Module matching this pattern should be nn.Linear with in_features + assert isinstance(module, nn.Linear), f"Expected nn.Linear, got {type(module)}" new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module) del module setattr(model, module_name, new_module) diff --git a/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py index 90ced13b2..0cc5aec81 100644 --- a/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py +++ b/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py @@ -141,15 +141,15 @@ def build_encoder_attention_mask( """ Build vision encoder attention mask that omits padding tokens. """ - masks = [] + masks_list = [] for arx in ar: mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype) mask_i[: arx[0] * arx[1], :ntok] = 0 mask_i = mask_i.view(num_chunks * x.shape[2], -1) mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype) mask_i = mask_i.unsqueeze(0) - masks.append(mask_i) - masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1) + masks_list.append(mask_i) + masks = torch.stack(masks_list).to(x.device).expand(-1, n_heads, -1, -1) return masks diff --git a/src/llama_stack/models/llama/llama3/multimodal/image_transform.py b/src/llama_stack/models/llama/llama3/multimodal/image_transform.py index 7b20a31fa..de2709c74 100644 --- a/src/llama_stack/models/llama/llama3/multimodal/image_transform.py +++ b/src/llama_stack/models/llama/llama3/multimodal/image_transform.py @@ -95,7 +95,7 @@ class VariableSizeImageTransform: factors_set.add(n // i) return factors_set - def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor: + def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> list[tuple[int, int]]: """ Computes all of the allowed resoltuions for a fixed number of chunks and patch_size. Useful for when dividing an image into chunks. @@ -198,10 +198,10 @@ class VariableSizeImageTransform: def resize_without_distortion( self, - image: torch.Tensor, + image: Image.Image, target_size: tuple[int, int], max_upscaling_size: int | None, - ) -> torch.Tensor: + ) -> Image.Image: """ Used to resize an image to target_resolution, without distortion. @@ -380,12 +380,12 @@ class VariableSizeImageTransform: assert isinstance(image, Image.Image), type(image) w, h = image.size - possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size) - possible_resolutions = torch.tensor(possible_resolutions) + possible_resolutions_list = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size) + possible_resolutions_tensor = torch.tensor(possible_resolutions_list) best_resolution = self.get_best_fit( image_size=(w, h), - possible_resolutions=possible_resolutions, + possible_resolutions=possible_resolutions_tensor, resize_to_max_canvas=resize_to_max_canvas, )