mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Added hadamard transform for spinquant (#326)
* Added hadamard transform for spinquant * Changed from config to model_args * Added an assertion for model args * Use enum.value to check against str * pre-commit --------- Co-authored-by: Sachin Mehta <sacmehta@fb.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
07f9bf723f
commit
c05fbf14b3
3 changed files with 109 additions and 14 deletions
|
@ -152,13 +152,22 @@ class Llama:
|
|||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from .quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
assert (
|
||||
config.quantization.scheme is not None
|
||||
), "Please specify a quantization scheme."
|
||||
|
||||
model = Transformer(model_args)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
if (
|
||||
model_args.quantization_args is not None
|
||||
and model_args.quantization_args.spinquant
|
||||
):
|
||||
# Add a wrapper for adding hadamard transform for spinquant.
|
||||
# This needs to be done after loading the state dict otherwise an error will be raised while
|
||||
# loading the state dict.
|
||||
from .quantization.hadamard_utils import (
|
||||
add_hadamard_transform_for_spinquant,
|
||||
)
|
||||
|
||||
add_hadamard_transform_for_spinquant(model)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently int4 and fp8 are the only supported quantization methods."
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Hadamard transform.
|
||||
|
||||
This function performs the Hadamard transform on the input tensor 'x'.
|
||||
The Hadamard transform is a linear transformation that multiplies the input
|
||||
tensor by the Hadamard matrix of dimension n x n, where n is the size of
|
||||
the last dimension of the input tensor.
|
||||
"""
|
||||
*_, n = x.shape
|
||||
m = int(math.log2(n))
|
||||
assert n == 1 << m, "n must be a power of 2"
|
||||
x = x[..., None]
|
||||
inv_sqrt2 = 0.5**0.5
|
||||
for _ in range(m):
|
||||
top = x[..., ::2, :] + x[..., 1::2, :]
|
||||
bot = x[..., ::2, :] - x[..., 1::2, :]
|
||||
x = torch.cat((top, bot), dim=-1)
|
||||
x *= inv_sqrt2
|
||||
res = x.squeeze(-2)
|
||||
return res
|
||||
|
||||
|
||||
class HadamardModule(torch.nn.Module):
|
||||
"""A module that applies the Hadamard transform to the input tensor.
|
||||
|
||||
Args:
|
||||
group_size: The size of the groups that the input tensor will be divided into
|
||||
before applying the Hadamard transform.
|
||||
"""
|
||||
|
||||
def __init__(self, group_size: int) -> None:
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
reshape_back = False
|
||||
orig_shape = x.shape
|
||||
if self.group_size != x.shape[-1]:
|
||||
reshape_back = True
|
||||
x = x.reshape(-1, x.shape[-1] // self.group_size, self.group_size)
|
||||
x = hadamard_transform(x)
|
||||
if reshape_back:
|
||||
x = x.reshape(orig_shape)
|
||||
return x
|
||||
|
||||
|
||||
def add_hadamard_transform_for_spinquant(
|
||||
model: torch.nn.Module, prefix: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
|
||||
This function recursively traverses the model's children and looks for layers that match the pattern
|
||||
"layers.<digit>.feed_forward.w2", where <digit> is one or more digits. When such a layer is found,
|
||||
it is replaced with a new sequential module that consists of a HadamardModule followed by the original
|
||||
layer. The HadamardModule applies the Hadamard transform to the input tensor.
|
||||
|
||||
See `SpinQuant <https://arxiv.org/abs/2405.16406>_` paper for more details.
|
||||
|
||||
Args:
|
||||
model: An instance of 'torch.nn.Module' (e.g., Transformer model).
|
||||
prefix: A string prefix to add to the full name of each child module.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
pattern_last_linear_ffn = r"layers.\d+.feed_forward.w2"
|
||||
for module_name, module in model.named_children():
|
||||
child_full_name = prefix + "." + module_name
|
||||
if re.search(pattern_last_linear_ffn, child_full_name):
|
||||
new_module = nn.Sequential(
|
||||
HadamardModule(group_size=module.in_features), module
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, new_module)
|
||||
else:
|
||||
add_hadamard_transform_for_spinquant(
|
||||
module, (prefix + "." if prefix else prefix) + module_name
|
||||
)
|
|
@ -26,7 +26,6 @@ from torch import nn, Tensor
|
|||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
from llama_stack.apis.inference.inference import Int4QuantizationConfig
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
|
@ -309,21 +308,16 @@ def convert_to_int4_quantized_model(
|
|||
) -> Transformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
|
||||
quant_config = config.quantization
|
||||
if not isinstance(quant_config, Int4QuantizationConfig):
|
||||
raise ValueError("Only int4 quantization is supported")
|
||||
if model_args.quantization_args is None:
|
||||
raise ValueError("'quantization_args' cannot be None. Please specify it.")
|
||||
|
||||
if quant_config.type != QuantizationType.int4.value:
|
||||
raise ValueError("Only int4 quantization is supported")
|
||||
quantization_args = model_args.quantization_args
|
||||
|
||||
if quant_config.scheme != "int4_weight_int8_dynamic_activation":
|
||||
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
|
||||
raise NotImplementedError(
|
||||
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
|
||||
)
|
||||
|
||||
if model_args.quantization_args is None:
|
||||
raise ValueError("'quantization_args' cannot be None. Please specify it.")
|
||||
|
||||
group_size = model_args.quantization_args.group_size
|
||||
if group_size is None:
|
||||
raise ValueError(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue