chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is to modernize the code base.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred in recent Python
releases.)
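
For context, here is a minimal sketch (illustrative helper names, not the actual `schema.py` code) of the distinction the reflection logic has to account for on Python 3.10+: PEP 604 unions such as `int | None` are instances of `types.UnionType` rather than `typing.Union`, and `typing.AsyncIterator[...]` resolves to `collections.abc.AsyncIterator` as its origin.

```python
# Illustrative helpers (assumed names, Python 3.10+); not the real schema.py API.
import collections.abc
import types
import typing


def is_union(tp) -> bool:
    # Optional[int] / Union[int, None] report typing.Union as their origin,
    # while `int | None` is an instance of types.UnionType.
    return typing.get_origin(tp) is typing.Union or isinstance(tp, types.UnionType)


def is_async_iterator(tp) -> bool:
    # Both typing.AsyncIterator[T] and collections.abc.AsyncIterator[T]
    # normalize to collections.abc.AsyncIterator as their origin.
    return typing.get_origin(tp) is collections.abc.AsyncIterator


assert is_union(typing.Optional[int]) and is_union(int | None)
assert is_async_iterator(typing.AsyncIterator[str])
assert is_async_iterator(collections.abc.AsyncIterator[str])
```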

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth noting can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.
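
For reviewers unfamiliar with pyupgrade, this is the general shape of the rewrites it applies (a hypothetical before/after, not a file from this repo): `Optional[X]`/`Union[X, Y]` become PEP 604 unions, `Dict`/`List` become the PEP 585 built-in generics, and `isinstance(x, (A, B))` becomes `isinstance(x, A | B)`.

```python
# Before (pre-pyupgrade style):
#   from typing import Dict, List, Optional
#
#   def load_scales(path: Optional[str] = None) -> Dict[str, List[float]]:
#       ...

# After (as rewritten by pyupgrade for Python 3.10+); the body is a
# placeholder, only the annotations matter here:
def load_scales(path: str | None = None) -> dict[str, list[float]]:
    return {}
```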

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions


@@ -6,7 +6,7 @@
# type: ignore
import os
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, cast
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@@ -37,9 +37,9 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
-quantization_mode: Optional[str] = None,
-fp8_activation_scale_ub: Optional[float] = 1200.0,
-device: Optional[torch.device] = None,
+quantization_mode: str | None = None,
+fp8_activation_scale_ub: float | None = 1200.0,
+device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
@@ -52,8 +52,8 @@ def convert_to_quantized_model(
def convert_to_fp8_quantized_model(
model: Transformer,
checkpoint_dir: str,
-fp8_activation_scale_ub: Optional[float] = 1200.0,
-device: Optional[torch.device] = None,
+fp8_activation_scale_ub: float | None = 1200.0,
+device: torch.device | None = None,
) -> Transformer:
# Move weights to GPU with quantization
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
@@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
-lora_rank: Optional[int] = None,
-lora_scale: Optional[float] = None,
+lora_rank: int | None = None,
+lora_scale: float | None = None,
) -> None:
super().__init__(
in_features,
@@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision=precision,
scales_precision=scales_precision,
)
-self.lora_scale: Optional[float] = None
-self.adaptor: Optional[nn.Sequential] = None
+self.lora_scale: float | None = None
+self.adaptor: nn.Sequential | None = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
@@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding):
def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear):
def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear):
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
-lora_rank: Optional[int],
-lora_scale: Optional[float],
+lora_rank: int | None,
+lora_scale: float | None,
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
@@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
)
del module
setattr(model, module_name, quantized_module)
-elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
+elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
@@ -286,7 +286,7 @@ def convert_to_int4_quantized_model(
def convert_to_int4_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
-device: Optional[torch.device] = None,
+device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params