mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-30 03:44:20 +00:00
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -19,9 +19,7 @@ try:
|
|||
|
||||
log.info("Using efficient FP8 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
log.error(
|
||||
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
|
||||
)
|
||||
log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
|
@ -60,14 +58,8 @@ def ffn_swiglu(
|
|||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if (
|
||||
isinstance(w1, Fp8ScaledWeights)
|
||||
and isinstance(w3, Fp8ScaledWeights)
|
||||
and isinstance(w2, Fp8ScaledWeights)
|
||||
):
|
||||
return ffn_swiglu_fp8_dynamic(
|
||||
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
|
||||
)
|
||||
if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
|
||||
return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
|
@ -146,12 +138,8 @@ def fc_fp8_dynamic(
|
|||
Single w8a8 fc layer with dynamic row-wise scaling.
|
||||
"""
|
||||
if isinstance(w, Fp8RowwiseWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
x, num_tokens, activation_scale_ub
|
||||
)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
xq, w.weight, x_scale, w.scale, use_fast_accum=True
|
||||
)
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
|
||||
del xq
|
||||
return y
|
||||
|
||||
|
|
|
@ -17,8 +17,7 @@ from torch import Tensor
|
|||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||
not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||
"Skip when H100 is not available",
|
||||
)
|
||||
class FP8Tests(unittest.TestCase):
|
||||
|
|
|
@ -57,9 +57,7 @@ class HadamardModule(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def add_hadamard_transform_for_spinquant(
|
||||
model: torch.nn.Module, prefix: str = ""
|
||||
) -> None:
|
||||
def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
|
||||
"""
|
||||
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
|
||||
This function recursively traverses the model's children and looks for layers that match the pattern
|
||||
|
@ -81,12 +79,8 @@ def add_hadamard_transform_for_spinquant(
|
|||
for module_name, module in model.named_children():
|
||||
child_full_name = prefix + "." + module_name
|
||||
if re.search(pattern_last_linear_ffn, child_full_name):
|
||||
new_module = nn.Sequential(
|
||||
HadamardModule(group_size=module.in_features), module
|
||||
)
|
||||
new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
|
||||
del module
|
||||
setattr(model, module_name, new_module)
|
||||
else:
|
||||
add_hadamard_transform_for_spinquant(
|
||||
module, (prefix + "." if prefix else prefix) + module_name
|
||||
)
|
||||
add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)
|
||||
|
|
|
@ -63,12 +63,8 @@ def convert_to_fp8_quantized_model(
|
|||
# Move weights to GPU with quantization
|
||||
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
log.info("Loading fp8 scales...")
|
||||
fp8_scales_path = os.path.join(
|
||||
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
assert os.path.isfile(
|
||||
fp8_scales_path
|
||||
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
for block in model.layers:
|
||||
|
@ -81,9 +77,7 @@ def convert_to_fp8_quantized_model(
|
|||
param = getattr(block.feed_forward, key)
|
||||
param.weight = load_fp8(
|
||||
param.weight,
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
|
||||
],
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
|
@ -172,9 +166,7 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
|||
if prefix + "zeros" not in state_dict:
|
||||
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
|
||||
assert prefix + "scales" in state_dict
|
||||
state_dict[prefix + "zeros"] = torch.zeros_like(
|
||||
state_dict[prefix + "scales"]
|
||||
)
|
||||
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
|
||||
|
||||
def forward(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
module_out = super().forward(input_)
|
||||
|
@ -229,9 +221,7 @@ class Int8WeightLinear(torch.nn.Linear):
|
|||
bias: Whether to use bias.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features: int, out_features: int, bias: bool = True, device=None
|
||||
) -> None:
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
|
||||
super().__init__(in_features, out_features, bias, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
@ -295,9 +285,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
|
|||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
else:
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
module, group_size, lora_rank, lora_scale
|
||||
)
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -321,9 +309,7 @@ def convert_to_int4_quantized_model(
|
|||
|
||||
group_size = model_args.quantization_args.group_size
|
||||
if group_size is None:
|
||||
raise ValueError(
|
||||
"'group_size' cannot be None in 'quantization_args'. Please specify it."
|
||||
)
|
||||
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
|
||||
|
||||
if model_args.lora_args is None:
|
||||
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
|
||||
|
@ -333,8 +319,6 @@ def convert_to_int4_quantized_model(
|
|||
lora_rank = model_args.lora_args.rank
|
||||
lora_scale = model_args.lora_args.scale
|
||||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model, group_size, lora_rank, lora_scale
|
||||
)
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
return model.to(device)
|
||||
|
|
|
@ -76,9 +76,9 @@ def main(
|
|||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(
|
||||
checkpoints
|
||||
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
assert model_parallel_size == len(checkpoints), (
|
||||
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
)
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
|
@ -90,9 +90,9 @@ def main(
|
|||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
assert model_args.vocab_size == tokenizer.n_words, (
|
||||
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
)
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
|
@ -106,9 +106,7 @@ def main(
|
|||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
log.info(ckpt_path)
|
||||
assert (
|
||||
quantized_ckpt_dir is not None
|
||||
), "QUantized checkpoint directory should not be None"
|
||||
assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
|
||||
fp8_scales = {}
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
|
@ -122,9 +120,7 @@ def main(
|
|||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w3.weight,
|
||||
|
@ -133,9 +129,7 @@ def main(
|
|||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w2.weight,
|
||||
|
@ -144,13 +138,9 @@ def main(
|
|||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale
|
||||
|
||||
fp8_scales_path = os.path.join(
|
||||
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
torch.save(fp8_scales, fp8_scales_path)
|
||||
|
||||
ckpt_path = os.path.join(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue