chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

The schema reflection code needed a minor adjustment to handle `types.UnionType`
and `collections.abc.AsyncIterator`. (Both are the preferred spellings in recent
Python releases.)
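
For context: `X | Y` annotations produce `types.UnionType` at runtime rather than `typing.Union`, and `typing.AsyncIterator` is an alias for `collections.abc.AsyncIterator`, so reflection code has to recognize both spellings. Below is a minimal illustrative sketch of the kind of check involved; it is not the actual code in `llama_stack/strong_typing/schema.py`, and the helper names are made up.

```python
# Illustrative sketch only; not the actual llama_stack/strong_typing/schema.py code.
import collections.abc
import types
import typing


def is_union_type(tp: object) -> bool:
    # `int | None` yields types.UnionType on Python 3.10+, while
    # Optional[int] / Union[int, None] still yield typing.Union.
    origin = typing.get_origin(tp)
    return origin is typing.Union or origin is types.UnionType


def is_async_iterator_type(tp: object) -> bool:
    # typing.AsyncIterator[...] resolves to collections.abc.AsyncIterator,
    # so a single origin check covers both spellings.
    return typing.get_origin(tp) is collections.abc.AsyncIterator


assert is_union_type(int | None)
assert is_union_type(typing.Optional[int])
assert is_async_iterator_type(collections.abc.AsyncIterator[str])
assert is_async_iterator_type(typing.AsyncIterator[str])
```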

Note to reviewers: almost all changes here were generated automatically by
pyupgrade, plus some cleanup of unused imports. The only changes worth noting
are under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`,
where the reflection code was updated to deal with the "newer" types.
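
As a reference for what the automated rewrites look like, here is a representative before/after sketch of the typing changes pyupgrade applies (PEP 585 builtin generics and PEP 604 unions); the function itself is a made-up example, not code from this repository.

```python
from typing import Any

# Before (typing-module generics and Optional):
#   def load_hook(state_dict: Dict[str, Any], missing_keys: List[str],
#                 mask: Optional[float] = None) -> Tuple[int, int]: ...

# After pyupgrade (PEP 585 builtin generics, PEP 604 `X | None` unions):
def load_hook(state_dict: dict[str, Any], missing_keys: list[str],
              mask: float | None = None) -> tuple[int, int]:
    return len(state_dict), len(missing_keys)
```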

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import fairscale.nn.model_parallel.initialize as fs_init
 import torch
@@ -89,7 +89,7 @@ def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@@ -174,13 +174,13 @@ class Attention(nn.Module):
 
     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "wqkv.weight" in state_dict:
             wqkv = state_dict.pop(prefix + "wqkv.weight")
@@ -200,7 +200,7 @@ class Attention(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
+        mask: torch.Tensor | None = None,
     ):
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@@ -288,13 +288,13 @@ class TransformerBlock(nn.Module):
 
     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
             state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
@@ -318,8 +318,8 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        global_attn_mask: Optional[torch.Tensor],
-        local_attn_mask: Optional[torch.Tensor],
+        global_attn_mask: torch.Tensor | None,
+        local_attn_mask: torch.Tensor | None,
     ):
         # The iRoPE architecture uses global attention mask for NoPE layers or
         # if chunked local attention is not used
@@ -374,13 +374,13 @@ class Transformer(nn.Module):
 
     def load_hook(
        self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "rope.freqs" in state_dict:
             state_dict.pop(prefix + "rope.freqs")