Added hadamard transform for spinquant (#326)

* Added hadamard transform for spinquant

* Changed from config to model_args

* Added an assertion for model args

* Use enum.value to check against str

* pre-commit

---------

Co-authored-by: Sachin Mehta <sacmehta@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Sachin Mehta 2024-10-25 12:58:48 -07:00 committed by GitHub
parent 07f9bf723f
commit c05fbf14b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 109 additions and 14 deletions

View file

@ -26,7 +26,6 @@ from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.apis.inference.inference import Int4QuantizationConfig
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceQuantizedInferenceConfig,
@ -309,21 +308,16 @@ def convert_to_int4_quantized_model(
) -> Transformer:
"""Convert the model to int4 quantized model."""
quant_config = config.quantization
if not isinstance(quant_config, Int4QuantizationConfig):
raise ValueError("Only int4 quantization is supported")
if model_args.quantization_args is None:
raise ValueError("'quantization_args' cannot be None. Please specify it.")
if quant_config.type != QuantizationType.int4.value:
raise ValueError("Only int4 quantization is supported")
quantization_args = model_args.quantization_args
if quant_config.scheme != "int4_weight_int8_dynamic_activation":
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
)
if model_args.quantization_args is None:
raise ValueError("'quantization_args' cannot be None. Please specify it.")
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError(