Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. The schema reflection code needed a minor adjustment to handle `types.UnionType` unions and `collections.abc.AsyncIterator` (both are the preferred spellings in recent Python releases).

Note to reviewers: almost all changes here were generated automatically by pyupgrade; some additional unused imports were cleaned up as well. The only change worth noting is under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
parent ffe3d0b2cd · commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
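The reflection change described above comes down to the fact that `X | Y` produces a `types.UnionType` at runtime rather than a `typing.Union`, so code that introspects annotations must accept both spellings. Below is a minimal sketch of that dual check, with assumed names rather than the actual code in `llama_stack/strong_typing/schema.py`:

```python
# Minimal sketch with assumed names; not the actual reflection code in
# llama_stack/strong_typing/schema.py.
import collections.abc
import types
import typing


def unwrap_optional(hint: object) -> object:
    """Return the non-None member of an Optional-style hint, else the hint.

    typing.Optional[str] and str | None are equivalent, but get_origin()
    reports typing.Union for the former and types.UnionType for the
    latter, so reflection code has to check for both origins.
    """
    origin = typing.get_origin(hint)
    if origin is typing.Union or origin is types.UnionType:
        args = [a for a in typing.get_args(hint) if a is not type(None)]
        if len(args) == 1:
            return args[0]
    return hint


assert unwrap_optional(typing.Optional[str]) is str
assert unwrap_optional(str | None) is str  # types.UnionType, Python 3.10+
# collections.abc.AsyncIterator subscripts directly and round-trips through
# get_origin(), which is why it is preferred over typing.AsyncIterator:
assert typing.get_origin(collections.abc.AsyncIterator[int]) is collections.abc.AsyncIterator
```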
```diff
@@ -6,7 +6,7 @@
 # type: ignore
 import os
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, cast
 
 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
```
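This import cleanup is the visible end of two pyupgrade rules: PEP 585 lets the builtins `dict`/`list` be subscripted directly, and PEP 604 replaces `Optional[X]`/`Union[X, Y]` with `X | None`/`X | Y`. A quick demonstration that the old and new spellings resolve identically at runtime (Python 3.10+):

```python
import typing
from typing import Any, Dict, Optional, Union

# PEP 585: the builtin generic carries the same origin as the typing alias.
assert typing.get_origin(Dict[str, Any]) is typing.get_origin(dict[str, Any]) is dict

# PEP 604: the | syntax builds an equivalent union.
assert Union[int, str] == (int | str)
assert typing.get_args(Optional[float]) == typing.get_args(float | None) == (float, type(None))
```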
```diff
@@ -37,9 +37,9 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer | CrossAttentionTransformer,
     checkpoint_dir: str,
-    quantization_mode: Optional[str] = None,
-    fp8_activation_scale_ub: Optional[float] = 1200.0,
-    device: Optional[torch.device] = None,
+    quantization_mode: str | None = None,
+    fp8_activation_scale_ub: float | None = 1200.0,
+    device: torch.device | None = None,
 ) -> Transformer | CrossAttentionTransformer:
     if quantization_mode == QuantizationMode.fp8_mixed:
         return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
```
```diff
@@ -52,8 +52,8 @@ def convert_to_quantized_model(
 def convert_to_fp8_quantized_model(
     model: Transformer,
     checkpoint_dir: str,
-    fp8_activation_scale_ub: Optional[float] = 1200.0,
-    device: Optional[torch.device] = None,
+    fp8_activation_scale_ub: float | None = 1200.0,
+    device: torch.device | None = None,
 ) -> Transformer:
     # Move weights to GPU with quantization
     fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
```
```diff
@@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
         # LoRA parameters
-        lora_rank: Optional[int] = None,
-        lora_scale: Optional[float] = None,
+        lora_rank: int | None = None,
+        lora_scale: float | None = None,
     ) -> None:
         super().__init__(
             in_features,
```
```diff
@@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
             precision=precision,
             scales_precision=scales_precision,
         )
-        self.lora_scale: Optional[float] = None
-        self.adaptor: Optional[nn.Sequential] = None
+        self.lora_scale: float | None = None
+        self.adaptor: nn.Sequential | None = None
         if lora_rank is not None:
             assert lora_scale is not None, "Please specify lora scale for LoRA."
             # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
```
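For readers unfamiliar with the `adaptor` being retyped here: the linked LoRA paper adds a trainable low-rank bypass around a frozen projection. A rough, self-contained sketch of that construction with assumed shapes and names (the class above stores the bypass as an `nn.Sequential` in the same spirit):

```python
import torch
from torch import nn

# Rough LoRA sketch, assumed names/shapes (https://arxiv.org/abs/2106.09685):
# y = W x + lora_scale * B(A(x)), where A projects down to the rank and B back up.
in_features, out_features, lora_rank, lora_scale = 16, 8, 4, 2.0

base = nn.Linear(in_features, out_features, bias=False)
adaptor = nn.Sequential(
    nn.Linear(in_features, lora_rank, bias=False),   # A: down-projection
    nn.Linear(lora_rank, out_features, bias=False),  # B: up-projection
)

x = torch.randn(2, in_features)
y = base(x) + lora_scale * adaptor(x)
assert y.shape == (2, out_features)
```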
```diff
@@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
 
     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         """A hook to load the quantized weights from the state dict."""
         if prefix + "zeros" not in state_dict:
```
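The `load_hook` signature here (and in the two classes that follow) matches the callback expected by PyTorch's `nn.Module._register_load_state_dict_pre_hook`, which runs before `load_state_dict` copies tensors and can therefore rewrite checkpoint entries on the fly. A hedged sketch of that wiring; the rescaling logic is invented for illustration:

```python
from typing import Any

import torch
from torch import nn


class ScaledLinear(nn.Linear):
    """Illustrative only: folds an auxiliary scale into the weight on load."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=False)
        # Private but long-standing PyTorch API; the bound method receives
        # exactly the seven arguments annotated in the diff above.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        # Pop the extra checkpoint entry (cf. the "zeros"/"scales" handling
        # above) so strict loading does not flag it as unexpected.
        scale = state_dict.pop(prefix + "scale", None)
        if scale is not None:
            state_dict[prefix + "weight"] = state_dict[prefix + "weight"] * scale


m = ScaledLinear(4, 4)
m.load_state_dict({"weight": torch.ones(4, 4), "scale": torch.tensor(0.5)})
assert torch.allclose(m.weight, torch.full((4, 4), 0.5))
```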
```diff
@@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding):
 
     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         """A hook to load the quantized embedding weight and scales from the state dict."""
         weights = state_dict.pop(prefix + "weight")
```
```diff
@@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear):
 
     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         """A hook to load the quantized linear weight and scales from the state dict."""
         weights = state_dict.pop(prefix + "weight")
```
```diff
@@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear):
 def _prepare_model_int4_weight_int8_dynamic_activation(
     model: torch.nn.Module,
     group_size: int,
-    lora_rank: Optional[int],
-    lora_scale: Optional[float],
+    lora_rank: int | None,
+    lora_scale: float | None,
 ):
     """Prepare the model for int4 weight and int8 dynamic activation quantization.
 
```
```diff
@@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
             )
             del module
             setattr(model, module_name, quantized_module)
-        elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
+        elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
             quantized_module = Int8DynActInt4WeightLinearLoRA(
                 in_features=module.in_features,
                 out_features=module.out_features,
```
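The `isinstance` rewrite in this hunk also rests on PEP 604: since Python 3.10, `isinstance` (and `issubclass`) accept a `|` union of types, so the tuple form and the union form are interchangeable:

```python
from torch import nn

layer = nn.Linear(4, 4)

# Equivalent checks on Python 3.10+ (PEP 604):
assert isinstance(layer, (nn.Linear, nn.Embedding))
assert isinstance(layer, nn.Linear | nn.Embedding)
```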
```diff
@@ -286,7 +286,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
 def convert_to_int4_quantized_model(
     model: Transformer | CrossAttentionTransformer,
     checkpoint_dir: str,
-    device: Optional[torch.device] = None,
+    device: torch.device | None = None,
 ) -> Transformer | CrossAttentionTransformer:
     """Convert the model to int4 quantized model."""
     model_args = model.params
```