diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index ebce1024b..ea2ae016d 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -152,13 +152,22 @@ class Llama:
         elif isinstance(config.quantization, Int4QuantizationConfig):
             from .quantization.loader import convert_to_int4_quantized_model

-            assert (
-                config.quantization.scheme is not None
-            ), "Please specify a quantization scheme."
-
             model = Transformer(model_args)
             model = convert_to_int4_quantized_model(model, model_args, config)
             model.load_state_dict(state_dict, strict=True)
+
+            if (
+                model_args.quantization_args is not None
+                and model_args.quantization_args.spinquant
+            ):
+                # Add a wrapper that applies the Hadamard transform for SpinQuant.
+                # This needs to be done after loading the state dict, otherwise an error
+                # is raised while loading the state dict.
+                from .quantization.hadamard_utils import (
+                    add_hadamard_transform_for_spinquant,
+                )
+
+                add_hadamard_transform_for_spinquant(model)
         else:
             raise NotImplementedError(
                 "Currently int4 and fp8 are the only supported quantization methods."
diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/hadamard_utils.py b/llama_stack/providers/impls/meta_reference/inference/quantization/hadamard_utils.py
new file mode 100644
index 000000000..f81a40951
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/inference/quantization/hadamard_utils.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import math
+import re
+
+import torch
+from torch import nn
+
+
+def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
+    """Hadamard transform.
+
+    This function performs the Hadamard transform on the input tensor 'x'.
+    The Hadamard transform is a linear transformation that multiplies the input
+    tensor by the Hadamard matrix of dimension n x n, where n is the size of
+    the last dimension of the input tensor.
+    """
+    *_, n = x.shape
+    m = int(math.log2(n))
+    assert n == 1 << m, "n must be a power of 2"
+    x = x[..., None]
+    inv_sqrt2 = 0.5**0.5
+    for _ in range(m):
+        top = x[..., ::2, :] + x[..., 1::2, :]
+        bot = x[..., ::2, :] - x[..., 1::2, :]
+        x = torch.cat((top, bot), dim=-1)
+        x *= inv_sqrt2
+    res = x.squeeze(-2)
+    return res
+
+
+class HadamardModule(torch.nn.Module):
+    """A module that applies the Hadamard transform to the input tensor.
+
+    Args:
+        group_size: The size of the groups that the input tensor will be divided into
+            before applying the Hadamard transform.
+    """
+
+    def __init__(self, group_size: int) -> None:
+        super().__init__()
+        self.group_size = group_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        reshape_back = False
+        orig_shape = x.shape
+        if self.group_size != x.shape[-1]:
+            reshape_back = True
+            x = x.reshape(-1, x.shape[-1] // self.group_size, self.group_size)
+        x = hadamard_transform(x)
+        if reshape_back:
+            x = x.reshape(orig_shape)
+        return x
+
+
+def add_hadamard_transform_for_spinquant(
+    model: torch.nn.Module, prefix: str = ""
+) -> None:
+    """
+    Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
+    This function recursively traverses the model's children and looks for layers that match the pattern
+    "layers.<layer_num>.feed_forward.w2", where <layer_num> is one or more digits. When such a layer is found,
+    it is replaced with a new sequential module that consists of a HadamardModule followed by the original
+    layer. The HadamardModule applies the Hadamard transform to the input tensor.
+
+    See the SpinQuant paper for more details.
+
+    Args:
+        model: An instance of 'torch.nn.Module' (e.g., Transformer model).
+        prefix: A string prefix to add to the full name of each child module.
+
+    Returns:
+        None
+    """
+
+    pattern_last_linear_ffn = r"layers.\d+.feed_forward.w2"
+    for module_name, module in model.named_children():
+        child_full_name = prefix + "." + module_name
+        if re.search(pattern_last_linear_ffn, child_full_name):
+            new_module = nn.Sequential(
+                HadamardModule(group_size=module.in_features), module
+            )
+            del module
+            setattr(model, module_name, new_module)
+        else:
+            add_hadamard_transform_for_spinquant(
+                module, (prefix + "." if prefix else prefix) + module_name
+            )
diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
index e07c9fa3b..9f30354bb 100644
--- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
+++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
@@ -26,7 +26,6 @@ from torch import nn, Tensor
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
-from llama_stack.apis.inference.inference import Int4QuantizationConfig

 from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceQuantizedInferenceConfig,
@@ -309,21 +308,16 @@ def convert_to_int4_quantized_model(
 ) -> Transformer:
     """Convert the model to int4 quantized model."""

-    quant_config = config.quantization
-    if not isinstance(quant_config, Int4QuantizationConfig):
-        raise ValueError("Only int4 quantization is supported")
+    if model_args.quantization_args is None:
+        raise ValueError("'quantization_args' cannot be None. Please specify it.")

-    if quant_config.type != QuantizationType.int4.value:
-        raise ValueError("Only int4 quantization is supported")
+    quantization_args = model_args.quantization_args

-    if quant_config.scheme != "int4_weight_int8_dynamic_activation":
+    if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
         raise NotImplementedError(
             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
         )

-    if model_args.quantization_args is None:
-        raise ValueError("'quantization_args' cannot be None. Please specify it.")
-
     group_size = model_args.quantization_args.group_size
     if group_size is None:
         raise ValueError(
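
For reviewers who want to poke at the new transform, here is a quick sanity check. It is not part of the diff and assumes this branch is installed so the new module is importable; it verifies that `hadamard_transform` is its own inverse and preserves norms (the normalized Hadamard matrix is symmetric and orthogonal), and that `HadamardModule` applies the transform per group when `group_size` is smaller than the last dimension.

```python
# Hand-run sanity check for the new hadamard_utils module (assumes this branch is installed).
import torch

from llama_stack.providers.impls.meta_reference.inference.quantization.hadamard_utils import (
    HadamardModule,
    hadamard_transform,
)

x = torch.randn(4, 32)  # the last dimension must be a power of 2

y = hadamard_transform(x)
# The normalized Hadamard matrix is symmetric and orthogonal, so applying the
# transform twice recovers the input and each row keeps its L2 norm.
assert torch.allclose(hadamard_transform(y), x, atol=1e-5)
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)

# With group_size smaller than the last dimension, HadamardModule transforms
# each contiguous group of `group_size` elements independently.
module = HadamardModule(group_size=8)
assert module(x).shape == x.shape
```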
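
The wrapper itself can be exercised on a hypothetical toy module whose names mimic the `layers.<i>.feed_forward.w2` layout the regex matches on (the classes below are illustrative, not taken from the repo): only the last FFN linear is wrapped in `nn.Sequential(HadamardModule, Linear)`, while `w1` and everything else are left untouched.

```python
# Illustrative toy model (hypothetical classes) showing what the wrapper rewrites.
import torch
from torch import nn

from llama_stack.providers.impls.meta_reference.inference.quantization.hadamard_utils import (
    add_hadamard_transform_for_spinquant,
)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


class Block(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.feed_forward = FeedForward(dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.feed_forward(x)


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # hidden_dim must be a power of 2 for the Hadamard transform.
        self.layers = nn.ModuleList([Block(dim=8, hidden_dim=16) for _ in range(2)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
add_hadamard_transform_for_spinquant(model)

# Each layers.<i>.feed_forward.w2 is now Sequential(HadamardModule, Linear);
# w1 is untouched, and the model still runs end to end.
print(model.layers[0].feed_forward.w2)
print(model(torch.randn(3, 8)).shape)  # torch.Size([3, 8])
```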