chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

The schema reflection code needed a minor adjustment to handle
UnionType and collections.abc.AsyncIterator. (Both are preferred in
recent Python releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
changes worth noting can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where the reflection code was
updated to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -6,8 +6,9 @@
import logging
import math
from collections.abc import Callable
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: Optional[bool] = False,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int],
bias: bool | None = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
@ -390,13 +391,13 @@ class VisionEncoder(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
missing_keys: list[str] = None,
unexpected_keys: list[str] = None,
error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
@ -641,7 +642,7 @@ class FeedForward(nn.Module):
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
ffn_dim_multiplier: float | None,
):
"""
Initialize the FeedForward module.
@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module):
self,
x: torch.Tensor,
xattn_mask: torch.Tensor,
full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
xattn_cache: torch.Tensor,
) -> torch.Tensor:
_attn_out = self.attention(
@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
def _init_fusion_schedule(
self,
num_layers: int,
) -> List[int]:
) -> list[int]:
llama_layers = list(range(self.n_llama_layers))
# uniformly spread the layers
@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_dtype,
vision_tokens,
cross_attention_masks,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
assert vision_tokens is not None, "Vision tokens must be provided"
vision_seqlen = vision_tokens.shape[3]
assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
batch_images: list[list[PIL_Image.Image]],
batch_masks: list[list[list[int]]],
total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"
@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def _stack_images(
images: List[List[PIL_Image.Image]],
images: list[list[PIL_Image.Image]],
max_num_chunks: int,
image_res: int,
max_num_images: int,
) -> Tuple[torch.Tensor, List[int]]:
) -> tuple[torch.Tensor, list[int]]:
"""
Takes a list of list of images and stacks them into a tensor.
This function is needed since images can be of completely
@ -1400,8 +1401,8 @@ def _stack_images(
def _pad_masks(
all_masks: List[List[List[int]]],
all_num_chunks: List[List[int]],
all_masks: list[list[list[int]]],
all_num_chunks: list[list[int]],
total_len: int,
max_num_chunks: int,
) -> torch.Tensor: