Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-11 19:56:03 +00:00)
fix(mypy): resolve model implementation typing issues
Fixes type errors across 4 model implementation files (Phase 2d):

- image_transform.py: Fix return type annotations and variable shadowing
  - Changed find_supported_resolutions return type from Tensor to list[tuple[int, int]]
  - Changed resize_without_distortion parameter/return type from Tensor to Image.Image
  - Use separate variable names for list vs tensor (possible_resolutions_list/tensor)
- checkpoint.py: Replace deprecated torch.BFloat16Tensor usage
  - Use torch.set_default_dtype(torch.bfloat16) instead of deprecated tensor types
  - Rename ckpt_paths to ckpt_paths_array to avoid variable shadowing
- hadamard_utils.py: Add type assertion for nn.Linear
  - Assert isinstance check to narrow the type from Module to Linear before accessing in_features
- encoder_utils.py: Fix variable shadowing
  - Use masks_list for list accumulation, masks for the final Tensor result

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent f34b6288cc
commit 257eaeb945

4 changed files with 17 additions and 15 deletions
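The common thread in three of these files is that mypy rejects re-binding one name to values of different types. A minimal standalone illustration of that pattern (not code from this repo; the function and names below are made up):

```python
import numpy as np

# Hypothetical example: re-using a parameter name for a numpy array changes the
# variable's inferred type, which mypy reports as an incompatible assignment.
# Giving the array its own name keeps both types stable.
def sort_paths(paths: list[str]) -> np.ndarray:
    # paths = np.array(sorted(paths))   # flagged: ndarray assigned to a list[str] variable
    paths_array = np.array(sorted(paths))
    return paths_array
```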
checkpoint.py

@@ -38,18 +38,18 @@ def maybe_reshard_state_dict(
     mmap: bool = True,
 ) -> dict[str, torch.Tensor]:
     if str(map_location) == "cpu":
-        torch.set_default_tensor_type(torch.BFloat16Tensor)
+        torch.set_default_dtype(torch.bfloat16)
     else:
-        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        torch.set_default_dtype(torch.bfloat16)

-    ckpt_paths = np.array(sorted(ckpt_paths))
+    ckpt_paths_array = np.array(sorted(ckpt_paths))

     new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
-    old_mp_size = len(ckpt_paths)
+    old_mp_size = len(ckpt_paths_array)
     old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

-    print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
-    paths = ckpt_paths[old_mp_ranks]  # type: ignore
+    print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}")  # type: ignore
+    paths = ckpt_paths_array[old_mp_ranks]  # type: ignore
     state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

     if new_mp_size == old_mp_size:
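For context on the checkpoint.py change: torch.set_default_tensor_type is deprecated in recent PyTorch releases, and torch.set_default_dtype is the documented replacement for controlling the default floating-point dtype. A minimal standalone sketch of the effect (not from the repo):

```python
import torch

# After this call, newly created floating-point tensors default to bfloat16.
torch.set_default_dtype(torch.bfloat16)

x = torch.empty(2, 2)
print(x.dtype)  # torch.bfloat16
```

Note that unlike torch.cuda.BFloat16Tensor, this only changes the default dtype; the default device is controlled separately (e.g. via torch.set_default_device).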
hadamard_utils.py

@@ -79,6 +79,8 @@ def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "
     for module_name, module in model.named_children():
         child_full_name = prefix + "." + module_name
         if re.search(pattern_last_linear_ffn, child_full_name):
+            # Module matching this pattern should be nn.Linear with in_features
+            assert isinstance(module, nn.Linear), f"Expected nn.Linear, got {type(module)}"
             new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
             del module
             setattr(model, module_name, new_module)
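The hadamard_utils.py change relies on isinstance narrowing: named_children() yields values typed as nn.Module, which has no in_features attribute, so the assert both documents the expectation and narrows the type for mypy. A small standalone sketch of the same pattern (nn.LayerNorm stands in for the repo's HadamardModule):

```python
import torch.nn as nn

def wrap_last_linear(module: nn.Module) -> nn.Module:
    # Without this assert, mypy flags `module.in_features`, since nn.Module has no such attribute.
    assert isinstance(module, nn.Linear), f"Expected nn.Linear, got {type(module)}"
    # After the assert, `module` is narrowed to nn.Linear, so the attribute access type-checks.
    return nn.Sequential(nn.LayerNorm(module.in_features), module)
```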
encoder_utils.py

@@ -141,15 +141,15 @@ def build_encoder_attention_mask(
     """
     Build vision encoder attention mask that omits padding tokens.
     """
-    masks = []
+    masks_list = []
     for arx in ar:
         mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
         mask_i[: arx[0] * arx[1], :ntok] = 0
         mask_i = mask_i.view(num_chunks * x.shape[2], -1)
         mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
         mask_i = mask_i.unsqueeze(0)
-        masks.append(mask_i)
-    masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
+        masks_list.append(mask_i)
+    masks = torch.stack(masks_list).to(x.device).expand(-1, n_heads, -1, -1)
     return masks
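In encoder_utils.py the fix is the accumulate-then-stack pattern: build the per-aspect-ratio masks in a plain list and bind the stacked result to a different name, so mypy sees list[Tensor] and Tensor as two variables rather than one re-typed variable. A toy version with made-up shapes:

```python
import torch

masks_list: list[torch.Tensor] = []
for _ in range(4):
    masks_list.append(torch.zeros(1, 8, 8))  # placeholder per-item mask

masks = torch.stack(masks_list)  # single Tensor of shape (4, 1, 8, 8)
print(masks.shape)
```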
image_transform.py

@@ -95,7 +95,7 @@ class VariableSizeImageTransform:
             factors_set.add(n // i)
         return factors_set

-    def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
+    def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> list[tuple[int, int]]:
         """
         Computes all of the allowed resoltuions for a fixed number of chunks
         and patch_size. Useful for when dividing an image into chunks.
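The image_transform.py signature change makes the return type match what the method actually builds: a Python list of (height, width) pairs rather than a Tensor. A simplified, hypothetical sketch of such a function (not the repo's implementation) showing the annotation in use:

```python
def find_supported_resolutions(max_num_chunks: int, patch_size: int) -> list[tuple[int, int]]:
    # Every factor pair (h, w) with h * w <= max_num_chunks yields one supported resolution.
    resolutions: list[tuple[int, int]] = []
    for n_chunks in range(1, max_num_chunks + 1):
        for width in range(1, n_chunks + 1):
            if n_chunks % width == 0:
                height = n_chunks // width
                resolutions.append((height * patch_size, width * patch_size))
    return resolutions

print(find_supported_resolutions(max_num_chunks=4, patch_size=224))
```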
image_transform.py

@@ -198,10 +198,10 @@ class VariableSizeImageTransform:

     def resize_without_distortion(
         self,
-        image: torch.Tensor,
+        image: Image.Image,
         target_size: tuple[int, int],
         max_upscaling_size: int | None,
-    ) -> torch.Tensor:
+    ) -> Image.Image:
         """
         Used to resize an image to target_resolution, without distortion.
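For resize_without_distortion, the new annotations reflect that the method operates on PIL images end to end (PIL's resize() returns another Image.Image). A hedged standalone sketch of an aspect-ratio-preserving resize with the same annotations, omitting the max_upscaling_size handling:

```python
from PIL import Image

def resize_without_distortion(image: Image.Image, target_size: tuple[int, int]) -> Image.Image:
    target_h, target_w = target_size  # this sketch assumes (height, width) ordering
    w, h = image.size  # PIL reports (width, height)
    scale = min(target_w / w, target_h / h)  # keep aspect ratio
    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
    return image.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
```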
image_transform.py

@@ -380,12 +380,12 @@ class VariableSizeImageTransform:
         assert isinstance(image, Image.Image), type(image)
         w, h = image.size

-        possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
-        possible_resolutions = torch.tensor(possible_resolutions)
+        possible_resolutions_list = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
+        possible_resolutions_tensor = torch.tensor(possible_resolutions_list)

         best_resolution = self.get_best_fit(
             image_size=(w, h),
-            possible_resolutions=possible_resolutions,
+            possible_resolutions=possible_resolutions_tensor,
             resize_to_max_canvas=resize_to_max_canvas,
         )
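The last hunk applies the same renaming idea at the call site: the list keeps the list[tuple[int, int]] type from find_supported_resolutions, and the tensor conversion gets its own name instead of re-binding possible_resolutions. A standalone illustration with made-up values:

```python
import torch

possible_resolutions_list: list[tuple[int, int]] = [(448, 448), (448, 896), (896, 448)]
possible_resolutions_tensor = torch.tensor(possible_resolutions_list)
print(possible_resolutions_tensor.shape)  # torch.Size([3, 2])
```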