Added Hadamard transform for SpinQuant (#326)

* Added Hadamard transform for SpinQuant

* Changed from config to model_args

* Added an assertion for model_args

* Use enum.value to check against str

* pre-commit

---------
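For context: SpinQuant rotates activations with a Hadamard matrix so that outliers are spread evenly across channels before quantization. Below is a minimal sketch of the idea using a normalized fast Walsh-Hadamard transform plus a wrapper module; fast_walsh_hadamard and HadamardWrapper are hypothetical names for illustration, not the actual implementation in quantization/hadamard_utils.py.

    import torch
    import torch.nn as nn


    def fast_walsh_hadamard(x: torch.Tensor) -> torch.Tensor:
        # Normalized fast Walsh-Hadamard transform along the last dimension.
        # Assumes the last dimension is a power of two; O(n log n) butterflies.
        orig_shape = x.shape
        n = orig_shape[-1]
        assert n & (n - 1) == 0, "last dimension must be a power of two"
        x = x.reshape(-1, n)
        h = 1
        while h < n:
            x = x.reshape(-1, n // (2 * h), 2, h)
            a, b = x[:, :, 0, :], x[:, :, 1, :]
            x = torch.stack((a + b, a - b), dim=2)
            h *= 2
        # Divide by sqrt(n) so the transform is orthonormal.
        return x.reshape(orig_shape) / n**0.5


    class HadamardWrapper(nn.Module):
        # Hypothetical wrapper: rotate inputs with a Hadamard transform
        # before the wrapped (quantized) linear layer sees them.
        def __init__(self, linear: nn.Linear):
            super().__init__()
            self.linear = linear

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(fast_walsh_hadamard(x))

Because the Hadamard matrix is orthogonal, the rotation preserves the layer's output up to the compensating rotation folded into the weights; its practical effect is to flatten activation outliers so low-bit quantization loses less accuracy.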

Co-authored-by: Sachin Mehta <sacmehta@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Sachin Mehta authored on 2024-10-25 12:58:48 -07:00, committed by GitHub
parent 07f9bf723f
commit c05fbf14b3
3 changed files with 109 additions and 14 deletions


@@ -152,13 +152,22 @@ class Llama:
        elif isinstance(config.quantization, Int4QuantizationConfig):
            from .quantization.loader import convert_to_int4_quantized_model

            assert (
                config.quantization.scheme is not None
            ), "Please specify a quantization scheme."

            model = Transformer(model_args)
            model = convert_to_int4_quantized_model(model, model_args, config)
            model.load_state_dict(state_dict, strict=True)

            if (
                model_args.quantization_args is not None
                and model_args.quantization_args.spinquant
            ):
                # Wrap the model to add the Hadamard transform for SpinQuant.
                # This must happen after load_state_dict; otherwise the added
                # wrapper modules change the module tree and loading the
                # state dict raises a key-mismatch error.
                from .quantization.hadamard_utils import (
                    add_hadamard_transform_for_spinquant,
                )

                add_hadamard_transform_for_spinquant(model)
        else:
            raise NotImplementedError(
                "Currently int4 and fp8 are the only supported quantization methods."