chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization: moving to the typing and syntax idioms of currently supported Python releases (built-in generics, `X | None` unions, `collections.abc` imports).

The schema reflection code needed a minor adjustment to handle `types.UnionType` (the type behind `X | Y` annotations) and `collections.abc.AsyncIterator`. Both are the preferred spellings on recent Python releases.
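
As a rough illustration only (a simplified sketch, not the actual code in `llama_stack/strong_typing/schema.py`; the helper name `annotation_members` is hypothetical), handling these newer forms during reflection typically means recognizing `types.UnionType` alongside `typing.Union` and unwrapping `collections.abc.AsyncIterator`:

```python
import collections.abc
import types
import typing


def annotation_members(annotation: object) -> list[object]:
    """Illustrative helper (not the project's API): return the member types
    of a union annotation, or the item type of an async iterator."""
    origin = typing.get_origin(annotation)

    # `int | None` reports types.UnionType as its origin, while the legacy
    # `typing.Optional[int]` / `typing.Union[int, None]` reports typing.Union.
    if origin is types.UnionType or origin is typing.Union:
        return list(typing.get_args(annotation))

    # Both typing.AsyncIterator[T] and collections.abc.AsyncIterator[T]
    # report collections.abc.AsyncIterator as their origin on modern Python.
    if origin is collections.abc.AsyncIterator:
        return list(typing.get_args(annotation))

    return [annotation]


assert annotation_members(int | None) == [int, type(None)]
assert annotation_members(collections.abc.AsyncIterator[str]) == [str]
```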

Note to reviewers: almost all changes here were generated automatically by pyupgrade; some additional unused imports were cleaned up as well. The only changes worth noting are under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to deal with the "newer" types.
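
For illustration, the bulk of the diff below consists of mechanical rewrites like the following (the function and parameter names here are hypothetical, shown only to summarize the pattern):

```python
# Before: typing-module generics, Optional/Union, and a redundant open() mode.
from typing import Dict, List, Optional, Tuple, Union

def load_params(path: str, overrides: Optional[Dict[str, int]] = None) -> Tuple[List[str], Union[int, float]]:
    with open(path, "r") as f:
        ...

# After pyupgrade: built-in generics (PEP 585) and `X | Y` unions (PEP 604);
# the redundant "r" mode is also dropped from open().
def load_params(path: str, overrides: dict[str, int] | None = None) -> tuple[list[str], int | float]:
    with open(path) as f:
        ...
```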

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions


@@ -5,7 +5,6 @@
# the root directory of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel, model_validator
@@ -15,8 +14,8 @@ class QuantizationScheme(Enum):
class QuantizationArgs(BaseModel):
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
scheme: QuantizationScheme | None = None
group_size: int | None = None
spinquant: bool = False
@@ -58,32 +57,32 @@ class ModelArgs(BaseModel):
dim: int = -1
n_layers: int = -1
n_heads: int = -1
n_kv_heads: Optional[int] = None
head_dim: Optional[int] = None
n_kv_heads: int | None = None
head_dim: int | None = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_exp: Optional[float] = None
ffn_dim_multiplier: float | None = None
ffn_exp: float | None = None
norm_eps: float = 1e-5
attention_chunk_size: Optional[int] = None
attention_chunk_size: int | None = None
rope_theta: float = 500000
use_scaled_rope: bool = False
rope_scaling_factor: Optional[float] = None
rope_high_freq_factor: Optional[float] = None
rope_scaling_factor: float | None = None
rope_high_freq_factor: float | None = None
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
nope_layer_interval: int | None = None # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
attn_temperature_tuning: bool = False
floor_scale: float = 8192.0
attn_scale: float = 0.1
vision_args: Optional[VisionArgs] = None
moe_args: Optional[MoEArgs] = None
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
vision_args: VisionArgs | None = None
moe_args: MoEArgs | None = None
quantization_args: QuantizationArgs | None = None
lora_args: LoRAArgs | None = None
max_batch_size: int = 32
max_seq_len: int = 2048


@@ -8,7 +8,6 @@ import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image as PIL_Image
@@ -46,10 +45,10 @@ def role_str(role: Role) -> str:
class TransformedImage:
image_tiles: torch.Tensor
# is the aspect ratio needed anywhere?
aspect_ratio: Tuple[int, int]
aspect_ratio: tuple[int, int]
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@@ -59,12 +58,12 @@ def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255
class ChatFormat:
possible_headers: Dict[Role, str]
possible_headers: dict[Role, str]
def __init__(
self,
tokenizer: Tokenizer,
vision_args: Optional[VisionArgs] = None,
vision_args: VisionArgs | None = None,
max_num_chunks: int = 16,
):
self.tokenizer = tokenizer
@@ -81,7 +80,7 @@ class ChatFormat:
vision_args.image_size.width, vision_args.image_size.height
)
def _encode_header(self, role: str) -> List[int]:
def _encode_header(self, role: str) -> list[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
@@ -98,7 +97,7 @@ class ChatFormat:
def _encode_image(
self,
transformed_image: TransformedImage,
) -> List[int]:
) -> list[int]:
assert self.vision_args is not None, "The model is not vision-enabled"
image_tensor = transformed_image.image_tiles
@@ -140,7 +139,7 @@ class ChatFormat:
return tokens
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[TransformedImage]]:
tokens = []
tranformed_images = []
@@ -189,7 +188,7 @@ class ChatFormat:
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[TransformedImage]]:
) -> tuple[list[int], list[TransformedImage]]:
tokens = self._encode_header(message.role)
images = []
@@ -223,7 +222,7 @@ class ChatFormat:
def encode_dialog_prompt(
self,
messages: List[RawMessage],
messages: list[RawMessage],
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> LLMInput:
tokens = []
@@ -240,7 +239,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
@@ -312,7 +311,7 @@ class ChatFormat:
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
def _model_input_from_tokens_images(self, tokens: list[int], images: list[TransformedImage]) -> LLMInput:
return LLMInput(
tokens=tokens,
images=[x.image_tiles for x in images] if len(images) > 0 else None,


@@ -5,7 +5,6 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
@@ -30,7 +29,7 @@ class LLMInput:
tokens: torch.Tensor
# images are already pre-processed (resized, tiled, etc.)
images: Optional[List[torch.Tensor]] = None
images: list[torch.Tensor] | None = None
@dataclass
@@ -45,8 +44,8 @@ class TransformerInput:
# tokens_position defines the position of the tokens in each batch,
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
# - when it is an int, the start position are the same for all batches
tokens_position: Union[torch.Tensor, int]
image_embedding: Optional[MaskedEmbedding] = None
tokens_position: torch.Tensor | int
image_embedding: MaskedEmbedding | None = None
@dataclass


@@ -11,7 +11,7 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from typing import Any, Dict, List
from typing import Any
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
@@ -36,13 +36,13 @@ class FeedForward(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)


@@ -10,8 +10,8 @@ import json
import os
import sys
import time
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
@@ -38,8 +38,8 @@ class Llama4:
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
world_size: int | None = None,
quantization_mode: QuantizationMode | None = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
@@ -63,7 +63,7 @@ class Llama4:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
@@ -117,15 +117,15 @@ class Llama4:
@torch.inference_mode()
def generate(
self,
llm_inputs: List[LLMInput],
llm_inputs: list[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
@@ -245,13 +245,13 @@ class Llama4:
def completion(
self,
contents: List[RawContent],
contents: list[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_inputs=llm_inputs,
@@ -267,13 +267,13 @@ class Llama4:
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
messages_batch: list[list[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_inputs=llm_inputs,


@@ -5,7 +5,7 @@
# the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@@ -89,7 +89,7 @@ def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@@ -174,13 +174,13 @@ class Attention(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
@@ -200,7 +200,7 @@ class Attention(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None,
mask: torch.Tensor | None = None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@@ -288,13 +288,13 @@ class TransformerBlock(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
@@ -318,8 +318,8 @@ class TransformerBlock(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
global_attn_mask: Optional[torch.Tensor],
local_attn_mask: Optional[torch.Tensor],
global_attn_mask: torch.Tensor | None,
local_attn_mask: torch.Tensor | None,
):
# The iRoPE architecture uses global attention mask for NoPE layers or
# if chunked local attention is not used
@@ -374,13 +374,13 @@ class Transformer(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "rope.freqs" in state_dict:
state_dict.pop(prefix + "rope.freqs")


@@ -6,7 +6,7 @@
# ruff: noqa: N806
# pyre-strict
from typing import Any, Dict, List
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@@ -63,13 +63,13 @@ class Experts(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
self.prefix = prefix
if prefix + "moe_w_in_eD_F" in state_dict:
@@ -158,13 +158,13 @@ class MoE(torch.nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "w_in_shared_FD.weight" in state_dict:
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
@@ -210,5 +210,5 @@ class MoE(torch.nn.Module):
def divide_exact(numerator: int, denominator: int) -> int:
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
return numerator // denominator


@@ -13,7 +13,6 @@
import math
from collections import defaultdict
from typing import Optional, Set, Tuple
import torch
import torchvision.transforms as tv
@@ -52,7 +51,7 @@ class ResizeNormalizeImageTransform:
return self.tv_transform(image)
class VariableSizeImageTransform(object):
class VariableSizeImageTransform:
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
@@ -100,7 +99,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
def get_factors(n: int) -> set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@@ -170,9 +169,9 @@ class VariableSizeImageTransform(object):
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
image_size: tuple[int, int],
target_size: tuple[int, int],
) -> tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
@@ -223,8 +222,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
target_size: tuple[int, int],
max_upscaling_size: int | None,
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
@@ -289,10 +288,10 @@ class VariableSizeImageTransform(object):
def get_best_fit(
self,
image_size: Tuple[int, int],
image_size: tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
) -> tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
@@ -392,7 +391,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
) -> tuple[torch.Tensor, tuple[int, int]]:
"""
Args:
image (PIL.Image): Image to be resized.


@@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
from typing import List, Optional
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
@@ -67,14 +66,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""".strip("\n")
)
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@@ -120,7 +119,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]},
).render()
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(


@@ -7,7 +7,6 @@
import textwrap
from io import BytesIO
from pathlib import Path
from typing import List
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator,
@@ -23,7 +22,7 @@ from ..prompt_format import (
THIS_DIR = Path(__file__).parent
def usecases(base_model: bool = False) -> List[UseCase | str]:
def usecases(base_model: bool = False) -> list[UseCase | str]:
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
img_small_dog = f.read()
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:


@@ -6,7 +6,7 @@
import logging
import os
from typing import Callable, Optional
from collections.abc import Callable
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@@ -45,8 +45,8 @@ def experts_batched_swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
quantization_mode: str | None = None,
fp8_activation_scale_ub: float | None = 1200.0,
use_rich_progress: bool = True,
) -> Transformer:
from ...quantize_impls import (
@@ -213,7 +213,7 @@ def logging_callbacks(
)
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
def update_status(message: str | None, completed: int | None = None) -> None:
if use_rich_progress:
if message is not None:
progress.update(task_id, status=message)


@@ -5,18 +5,11 @@
# the root directory of this source tree.
import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
@@ -114,7 +107,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
special_tokens: dict[str, int]
num_reserved_special_tokens = 2048
@@ -182,9 +175,9 @@ class Tokenizer:
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
allowed_special: Literal["all"] | Set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = (),
) -> list[int]:
"""
Encodes a string into a list of token IDs.
@@ -217,7 +210,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
t: list[int] = []
for substr in substrs:
t.extend(
self.model.encode(
@@ -243,7 +236,7 @@ class Tokenizer:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
return self.model.decode(cast(list[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import math
from typing import Any, Callable, Dict, List
from collections.abc import Callable
from typing import Any
import torch
import torch.nn as nn
@@ -136,13 +137,13 @@ class VisionEmbeddings(torch.nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
missing_keys: list[str] = None,
unexpected_keys: list[str] = None,
error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
original_sd = self.state_dict()
@@ -163,7 +164,7 @@ class VisionEmbeddings(torch.nn.Module):
# each image is a tensor of shape [num_tiles, C, H, W]
def forward(
self,
image_batch: List[List[torch.Tensor]],
image_batch: list[list[torch.Tensor]],
image_mask: torch.Tensor,
h_ref: torch.Tensor,
) -> torch.Tensor:


@@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from collections.abc import Callable
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@@ -42,9 +43,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: Optional[bool] = False,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int],
bias: bool | None = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
@@ -134,15 +135,15 @@ class _TransformerBlock(nn.Module):
def attention(
self,
x: torch.Tensor,
freq_cis: Optional[torch.Tensor] = None,
freq_cis: torch.Tensor | None = None,
):
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
freq_cis: Optional[torch.Tensor] = None,
mask: torch.Tensor | None = None,
freq_cis: torch.Tensor | None = None,
):
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
@@ -210,8 +211,8 @@ class PackingIndex:
class VisionEncoder(nn.Module):
def __init__(
self,
image_size: Tuple[int, int],
patch_size: Tuple[int, int],
image_size: tuple[int, int],
patch_size: tuple[int, int],
dim: int,
layers: int,
heads: int,
@@ -299,13 +300,13 @@ class VisionEncoder(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
missing_keys: list[str] = None,
unexpected_keys: list[str] = None,
error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")