Added an assertion for model args

This commit is contained in:
Sachin Mehta 2024-10-25 12:33:27 -07:00
parent db4f18099f
commit 8e0a4e2885

View file

@ -155,7 +155,10 @@ class Llama:
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
if model_args.quantization_args.spinquant:
if (
model_args.quantization_args is not None
and model_args.quantization_args.spinquant
):
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.