# What does this PR do?

Move bits around. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference-specific tidbits into llama-models code even by accident.

Also, kills the `meta-reference-quantized-gpu` distro and rolls the quantization deps into `meta-reference-gpu`.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --env INFERENCE_CHECKPOINT_DIR=<DIR> \
  --env MODEL_PARALLEL_SIZE=4 \
  --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to it using:

```
pytest -s -v tests/integration/inference/test_text_inference.py \
  --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# type: ignore
import os
from typing import Any, Dict, List, Optional, cast

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

from ...datatypes import QuantizationMode
from ...quantize_impls import (
    Fp8ScaledWeights,
    ffn_swiglu,
    load_fp8,
    quantize_fp8,
)
from ..model import Transformer, TransformerBlock
from ..multimodal.model import CrossAttentionTransformer


def swiglu_wrapper(
    self,
    x: Tensor,
):
    out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
    return reduce_from_model_parallel_region(out)


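# NOTE: swiglu_wrapper is an unbound function so that the fp8 path below can bind it onto
# an existing FeedForward module via `swiglu_wrapper.__get__(block.feed_forward)`; the
# fp8-aware ffn_swiglu kernel then replaces that module's forward while reusing its
# w1/w3/w2 weights.

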
def convert_to_quantized_model(
    model: Transformer | CrossAttentionTransformer,
    checkpoint_dir: str,
    quantization_mode: Optional[str] = None,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
    device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
    if quantization_mode == QuantizationMode.fp8_mixed:
        return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
    elif quantization_mode == QuantizationMode.int4_mixed:
        return convert_to_int4_quantized_model(model, checkpoint_dir, device)
    else:
        raise ValueError(f"Unsupported quantization mode: {quantization_mode}")


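# Illustrative usage sketch (not invoked by this module): a caller on the checkpoint-loading
# path might dispatch into the converter roughly like this. `model` and `ckpt_dir` are
# assumed to exist already; they are not defined here.
#
#   model = convert_to_quantized_model(
#       model,
#       checkpoint_dir=ckpt_dir,
#       quantization_mode=QuantizationMode.fp8_mixed,  # or QuantizationMode.int4_mixed
#       device=torch.device("cuda"),
#   )

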
def convert_to_fp8_quantized_model(
    model: Transformer,
    checkpoint_dir: str,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
    device: Optional[torch.device] = None,
) -> Transformer:
    # Move weights to GPU with quantization
    fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
    if os.path.isfile(fp8_scales_path):
        print("Loading fp8 scales...")
        fp8_scales = torch.load(fp8_scales_path, weights_only=True)

        for _, block in model.named_modules():
            if isinstance(block, TransformerBlock):
                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                    continue

                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
                for key in ("w1", "w3", "w2"):
                    param = getattr(block.feed_forward, key)
                    param.weight = load_fp8(
                        param.weight,
                        fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
                        fp8_activation_scale_ub,
                    )
    else:
        print("Quantizing fp8 weights from bf16...")
        for _, block in model.named_modules():
            if isinstance(block, TransformerBlock):
                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                    continue
                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)  # type: ignore
                for key in ("w1", "w3", "w2"):
                    param = getattr(block.feed_forward, key)
                    param.weight = quantize_fp8(
                        param.weight,
                        fp8_activation_scale_ub,
                        output_device=device,
                    )

    for _, parameter in model.named_parameters():
        if not isinstance(parameter, Fp8ScaledWeights):
            parameter.data = parameter.to(device=device)
    return model


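# Illustrative sketch (not invoked by this module): convert_to_fp8_quantized_model expects
# fp8_scales_<rank>.pt, when present, to be a flat dict keyed per layer, weight name, and
# model-parallel rank, matching the lookup above. For rank 0 and layer 5 that would be
# roughly:
#
#   {
#       "5_feed_forward.w1_0": <scale tensor>,
#       "5_feed_forward.w3_0": <scale tensor>,
#       "5_feed_forward.w2_0": <scale tensor>,
#       ...
#   }
#
# Note that the first and last transformer blocks are skipped in both branches and stay
# unquantized.

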
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
    """
    Int8DynActInt4WeightLinear with LoRA adaptor.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
        device: Device to use.
        group_size: Group size for quantization.
        precision: Precision of quantization.
        scales_precision: Precision of scales.
        lora_rank: Rank of LoRA adaptor.
        lora_scale: Scale of LoRA adaptor.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=False,
        device=None,
        # quantization parameters
        group_size: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
        # LoRA parameters
        lora_rank: Optional[int] = None,
        lora_scale: Optional[float] = None,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            groupsize=group_size,
            precision=precision,
            scales_precision=scales_precision,
        )
        self.lora_scale: Optional[float] = None
        self.adaptor: Optional[nn.Sequential] = None
        if lora_rank is not None:
            assert lora_scale is not None, "Please specify lora scale for LoRA."
            # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
            self.adaptor = nn.Sequential()
            self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
            self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
            self.lora_scale = lora_scale
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """A hook to load the quantized weights from the state dict."""
        if prefix + "zeros" not in state_dict:
            # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
            assert prefix + "scales" in state_dict
            state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        module_out = super().forward(input_)
        if self.adaptor is not None:
            adaptor_out = self.adaptor(input_) * self.lora_scale
            return module_out + adaptor_out
        return module_out


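# Illustrative usage sketch (not invoked by this module): constructing the LoRA-augmented
# quantized linear directly and running a forward pass. The shapes and LoRA settings below
# are arbitrary and only for illustration.
#
#   layer = Int8DynActInt4WeightLinearLoRA(
#       in_features=4096,
#       out_features=4096,
#       group_size=256,
#       lora_rank=16,
#       lora_scale=2.0,
#   )
#   x = torch.randn(1, 16, 4096)
#   y = layer(x)  # base int4-weight/int8-activation output plus lora_scale * B(A(x))

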
class Int8WeightEmbedding(torch.nn.Embedding):
    """An embedding layer to load int8 weights.

    Args:
        num_embeddings: Number of embeddings.
        embedding_dim: Embedding dimension.
        padding_idx: Padding index.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        device=None,
    ) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """A hook to load the quantized embedding weight and scales from the state dict."""
        weights = state_dict.pop(prefix + "weight")
        scales = state_dict.pop(prefix + "scales")
        state_dict[prefix + "weight"] = weights * scales


class Int8WeightLinear(torch.nn.Linear):
    """A linear layer to load int8 weights.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
        super().__init__(in_features, out_features, bias, device=device)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """A hook to load the quantized linear weight and scales from the state dict."""
        weights = state_dict.pop(prefix + "weight")
        scales = state_dict.pop(prefix + "scales")
        state_dict[prefix + "weight"] = weights * scales


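# Illustrative sketch (not invoked by this module): what the load_hook of Int8WeightEmbedding
# and Int8WeightLinear does to an incoming checkpoint entry. A state dict holding int8
# weights plus scales, e.g.
#
#   state_dict = {
#       "output.weight": <int8 tensor>,
#       "output.scales": <scale tensor broadcastable against the weight>,
#   }
#
# is rewritten in place so that "output.weight" holds the dequantized weight * scales product
# and "output.scales" is removed before the regular nn.Module loading path runs.

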
def _prepare_model_int4_weight_int8_dynamic_activation(
    model: torch.nn.Module,
    group_size: int,
    lora_rank: Optional[int],
    lora_scale: Optional[float],
):
    """Prepare the model for int4 weight and int8 dynamic activation quantization.

    Note that the weights of embedding and output layers are quantized to int8.
    """
    device = None
    for module_name, module in model.named_children():
        if module_name == "output":
            quantized_module = Int8WeightLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        elif module_name == "tok_embeddings":
            quantized_module = Int8WeightEmbedding(
                num_embeddings=module.num_embeddings,
                embedding_dim=module.embedding_dim,
                padding_idx=module.padding_idx,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
            quantized_module = Int8DynActInt4WeightLinearLoRA(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=False,
                group_size=group_size,
                lora_rank=lora_rank,
                lora_scale=lora_scale,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        else:
            _prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)

    return model


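# Illustrative usage sketch (not invoked by this module): applying the int4 preparation to a
# toy module tree. Children named "output" become Int8WeightLinear, "tok_embeddings" becomes
# Int8WeightEmbedding, any other (parallel) linear becomes Int8DynActInt4WeightLinearLoRA,
# and containers are recursed into. The toy sizes below are arbitrary.
#
#   toy = nn.ModuleDict({"proj": nn.Linear(64, 64, bias=False)})
#   _prepare_model_int4_weight_int8_dynamic_activation(toy, group_size=32, lora_rank=None, lora_scale=None)
#   assert isinstance(toy["proj"], Int8DynActInt4WeightLinearLoRA)

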
def convert_to_int4_quantized_model(
    model: Transformer | CrossAttentionTransformer,
    checkpoint_dir: str,
    device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
    """Convert the model to int4 quantized model."""
    model_args = model.params
    assert model_args.quantization_args is not None, "Quantization args must be specified."
    quantization_args = model_args.quantization_args
    if quantization_args.scheme is None:
        raise ValueError("Quantization scheme must be specified in 'quantization_args'.")

    if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
        raise NotImplementedError(
            "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
        )

    group_size = model_args.quantization_args.group_size
    if group_size is None:
        raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")

    if model_args.lora_args is None:
        # Certain quantized models (e.g., SpinQuant) may not have LoRA.
        lora_rank = None
        lora_scale = None
    else:
        lora_rank = model_args.lora_args.rank
        lora_scale = model_args.lora_args.scale

    _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
    return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
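

# Illustrative usage sketch (not invoked by this module): the int4 path reads its settings from
# the model's params rather than from function arguments, roughly:
#
#   assert model.params.quantization_args.scheme.value == "int4_weight_int8_dynamic_activation"
#   assert model.params.quantization_args.group_size is not None  # e.g. 256
#
#   quantized = convert_to_int4_quantized_model(model, checkpoint_dir, device=torch.device("cuda"))
#
# Conversion only swaps in the quantized module classes; the actual int4 weights and scales are
# filled in when the checkpoint's state dict is loaded into the converted model.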