Added hadamard transform for spinquant (#326)

* Added hadamard transform for spinquant

* Changed from config to model_args

* Added an assertion for model args

* Use enum.value to check against str

* pre-commit

---------

Co-authored-by: Sachin Mehta <sacmehta@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Sachin Mehta 2024-10-25 12:58:48 -07:00 committed by GitHub
parent 07f9bf723f
commit c05fbf14b3
3 changed files with 109 additions and 14 deletions


@@ -152,13 +152,22 @@ class Llama:
         elif isinstance(config.quantization, Int4QuantizationConfig):
             from .quantization.loader import convert_to_int4_quantized_model

+            assert (
+                config.quantization.scheme is not None
+            ), "Please specify a quantization scheme."
+
             model = Transformer(model_args)
             model = convert_to_int4_quantized_model(model, model_args, config)
             model.load_state_dict(state_dict, strict=True)
+
+            if (
+                model_args.quantization_args is not None
+                and model_args.quantization_args.spinquant
+            ):
+                # Add a wrapper that applies the Hadamard transform for SpinQuant.
+                # This needs to be done after loading the state dict; otherwise
+                # an error will be raised while loading the state dict.
+                from .quantization.hadamard_utils import (
+                    add_hadamard_transform_for_spinquant,
+                )
+
+                add_hadamard_transform_for_spinquant(model)
         else:
             raise NotImplementedError(
                 "Currently int4 and fp8 are the only supported quantization methods."


@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import math
import re

import torch
from torch import nn


def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Hadamard transform.

    This function performs the Hadamard transform on the input tensor 'x'.
    The Hadamard transform is a linear transformation that multiplies the input
    tensor by the Hadamard matrix of dimension n x n, where n is the size of
    the last dimension of the input tensor.
    """
    *_, n = x.shape
    m = int(math.log2(n))
    assert n == 1 << m, "n must be a power of 2"

    x = x[..., None]
    inv_sqrt2 = 0.5**0.5
    for _ in range(m):
        top = x[..., ::2, :] + x[..., 1::2, :]
        bot = x[..., ::2, :] - x[..., 1::2, :]
        x = torch.cat((top, bot), dim=-1)
        x *= inv_sqrt2
    return x.squeeze(-2)
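As a quick self-check of the butterfly loop above: it computes the Sylvester-ordered Walsh-Hadamard transform scaled by 1/sqrt(n), so it should agree with an explicit matrix multiply and be its own inverse. A small sketch, assuming hadamard_transform is importable from this module; sylvester_hadamard is a helper written here for the check:

import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # Build the (unnormalized) Sylvester-ordered Hadamard matrix
    # by repeated Kronecker products with H2.
    h = torch.ones(1, 1)
    h2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
    while h.shape[0] < n:
        h = torch.kron(h, h2)
    return h

x = torch.randn(3, 8)
expected = x @ (sylvester_hadamard(8) / 8**0.5)  # H is symmetric, so H == H.T
assert torch.allclose(hadamard_transform(x), expected, atol=1e-5)
# The normalized transform is orthogonal and symmetric:
# applying it twice recovers the input.
assert torch.allclose(hadamard_transform(hadamard_transform(x)), x, atol=1e-5)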
class HadamardModule(torch.nn.Module):
    """A module that applies the Hadamard transform to the input tensor.

    Args:
        group_size: The size of the groups that the input tensor will be divided into
            before applying the Hadamard transform.
    """

    def __init__(self, group_size: int) -> None:
        super().__init__()
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        reshape_back = False
        orig_shape = x.shape
        if self.group_size != x.shape[-1]:
            reshape_back = True
            x = x.reshape(-1, x.shape[-1] // self.group_size, self.group_size)
        x = hadamard_transform(x)
        if reshape_back:
            x = x.reshape(orig_shape)
        return x
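For intuition on group_size: when it is smaller than the last dimension, forward reshapes the input so the transform runs independently on each contiguous block of group_size elements. A usage sketch, assuming HadamardModule and hadamard_transform are importable from this module:

import torch

m = HadamardModule(group_size=4)
x = torch.randn(2, 16)  # last dim must be a multiple of group_size
y = m(x)
assert y.shape == x.shape
# Each contiguous block of 4 elements is transformed on its own:
assert torch.allclose(y[0, :4], hadamard_transform(x[0, :4]))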
def add_hadamard_transform_for_spinquant(
    model: torch.nn.Module, prefix: str = ""
) -> None:
    """
    Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
    This function recursively traverses the model's children and looks for layers that match the pattern
    "layers.<digit>.feed_forward.w2", where <digit> is one or more digits. When such a layer is found,
    it is replaced with a new sequential module that consists of a HadamardModule followed by the original
    layer. The HadamardModule applies the Hadamard transform to the input tensor.

    See the `SpinQuant <https://arxiv.org/abs/2405.16406>`_ paper for more details.

    Args:
        model: An instance of 'torch.nn.Module' (e.g., a Transformer model).
        prefix: A string prefix to add to the full name of each child module.

    Returns:
        None
    """
    pattern_last_linear_ffn = r"layers\.\d+\.feed_forward\.w2"
    for module_name, module in model.named_children():
        child_full_name = prefix + "." + module_name
        if re.search(pattern_last_linear_ffn, child_full_name):
            new_module = nn.Sequential(
                HadamardModule(group_size=module.in_features), module
            )
            del module
            setattr(model, module_name, new_module)
        else:
            add_hadamard_transform_for_spinquant(
                module, (prefix + "." if prefix else prefix) + module_name
            )
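To see the rewrite in action, here is a toy model whose module names mimic the layers.<i>.feed_forward.w2 layout the regex targets; the Toy* classes are illustrative only, and HadamardModule and add_hadamard_transform_for_spinquant are assumed importable from this module:

import torch
from torch import nn

class ToyFFN(nn.Module):
    def __init__(self, dim: int, hidden: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

class ToyBlock(nn.Module):
    def __init__(self, dim: int, hidden: int) -> None:
        super().__init__()
        self.feed_forward = ToyFFN(dim, hidden)

class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Module names become "layers.0.feed_forward.w2", etc.
        self.layers = nn.ModuleList(ToyBlock(8, 16) for _ in range(2))

model = ToyModel()
add_hadamard_transform_for_spinquant(model)
# Each FFN's w2 is now Sequential(HadamardModule, original Linear):
assert isinstance(model.layers[0].feed_forward.w2, nn.Sequential)
assert isinstance(model.layers[0].feed_forward.w2[0], HadamardModule)
# The replaced layer is still callable; the transform runs before the Linear.
y = model.layers[0].feed_forward.w2(torch.randn(2, 16))
assert y.shape == (2, 8)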


@@ -26,7 +26,6 @@ from torch import nn, Tensor
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
-from llama_stack.apis.inference.inference import Int4QuantizationConfig

 from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceQuantizedInferenceConfig,
@@ -309,21 +308,16 @@ def convert_to_int4_quantized_model(
 ) -> Transformer:
     """Convert the model to int4 quantized model."""

-    quant_config = config.quantization
-    if not isinstance(quant_config, Int4QuantizationConfig):
-        raise ValueError("Only int4 quantization is supported")
-
-    if quant_config.type != QuantizationType.int4.value:
-        raise ValueError("Only int4 quantization is supported")
-
-    if quant_config.scheme != "int4_weight_int8_dynamic_activation":
+    if model_args.quantization_args is None:
+        raise ValueError("'quantization_args' cannot be None. Please specify it.")
+
+    quantization_args = model_args.quantization_args
+
+    if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
         raise NotImplementedError(
             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
         )

-    if model_args.quantization_args is None:
-        raise ValueError("'quantization_args' cannot be None. Please specify it.")
-
     group_size = model_args.quantization_args.group_size
     if group_size is None:
         raise ValueError(