several fixes

This commit is contained in:
Ashwin Bharambe 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37
60 changed files with 1006 additions and 1078 deletions

View file

@ -6,20 +6,29 @@
import logging
import os
from typing import Optional
from typing import Callable, Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor
from torch import Tensor, nn
from torch.nn import functional as F
from ..generation import QuantizationMode
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
def swiglu_wrapper_no_reduce(
self,
x: Tensor,
):
from ...quantize_impls import ffn_swiglu
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
def experts_batched_swiglu_wrapper(
self,
x: Tensor, # (e, g, D)
@ -51,24 +60,30 @@ def convert_to_quantized_model(
rank = get_model_parallel_rank()
def should_quantize_block(block: nn.Module) -> bool:
if not isinstance(block, TransformerBlock):
return False
is_moe = isinstance(block.feed_forward, MoE)
if quantization_mode == QuantizationMode.fp8_mixed:
# skip quantization on first and last layers
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
return is_moe
use_rich_progress = use_rich_progress and rank == 0
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True)
int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
def apply_quantization(key, weight):
scale = int4_scales[key]
zero_point = int4_zero_points[key]
return load_int4(
weight,
scale,
zero_point,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
@ -76,7 +91,8 @@ def convert_to_quantized_model(
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
return quantize_int4(weight, output_device=torch.device("cuda"))
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
if os.path.isfile(fp8_scales_path):
@ -104,33 +120,38 @@ def convert_to_quantized_model(
progress.start()
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
# Skip quantization on first and last layers
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
if not should_quantize_block(block):
continue
# Skip quantization on dense layers
if not isinstance(block.feed_forward, MoE):
continue
update_status(f"Rank {rank} - Layer {block.layer_id}")
update_status(f"Rank {rank} - Layer {block.layer_id}")
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(
f"{prefix}.experts.{key}",
param.transpose(1, 2).contiguous(),
),
)
if quantization_mode == QuantizationMode.int4_mixed:
# Quantize shared experts
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
)
param = getattr(moe.shared_expert, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
processed_blocks += 1
update_status(message=None, completed=processed_blocks)
processed_blocks += 1
update_status(message=None, completed=processed_blocks)
update_status(f"Rank {rank} - Moving parameters to CUDA")
@ -149,7 +170,12 @@ def convert_to_quantized_model(
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
def logging_callbacks(
use_rich_progress: bool,
rank: int,
model: Transformer,
should_quantize_block: Callable[[nn.Module], bool],
):
console = None
if use_rich_progress:
from rich.console import Console
@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message)
total_blocks = sum(
1
for _, block in model.named_modules()
if (
isinstance(block, TransformerBlock)
and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
and isinstance(block.feed_forward, MoE)
)
)
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
progress = None
if use_rich_progress:
from rich.progress import (