Changed from config to model_args

Sachin Mehta 2024-10-25 12:24:50 -07:00
parent 93472042f8
commit db4f18099f
2 changed files with 9 additions and 18 deletions


@@ -151,15 +151,11 @@ class Llama:
             elif isinstance(config.quantization, Int4QuantizationConfig):
                 from .quantization.loader import convert_to_int4_quantized_model
 
-                assert (
-                    config.quantization.scheme is not None
-                ), "Please specify a quantization scheme."
-
                 model = Transformer(model_args)
                 model = convert_to_int4_quantized_model(model, model_args, config)
                 model.load_state_dict(state_dict, strict=True)
 
-                if config.quantization.spinquant:
+                if model_args.quantization_args.spinquant:
                     # Add a wrapper for adding hadamard transform for spinquant.
                     # This needs to be done after loading the state dict otherwise an error will be raised while
                     # loading the state dict.
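
For context, a minimal sketch of what the change amounts to: the spinquant decision is now driven by the checkpoint-derived model_args.quantization_args rather than the runtime config.quantization. The QuantizationArgs and ModelArgs classes below are illustrative stand-ins, not the repository's actual definitions.

from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins, shown only to illustrate where the spinquant
# flag now lives after this commit.
@dataclass
class QuantizationArgs:
    scheme: Optional[str] = None
    spinquant: bool = False


@dataclass
class ModelArgs:
    dim: int = 4096
    quantization_args: Optional[QuantizationArgs] = None


def needs_hadamard_wrapper(model_args: ModelArgs) -> bool:
    # Post-commit behavior: consult model_args.quantization_args, which is
    # populated from the checkpoint, instead of the inference config.
    return (
        model_args.quantization_args is not None
        and model_args.quantization_args.spinquant
    )


if __name__ == "__main__":
    args = ModelArgs(quantization_args=QuantizationArgs(scheme="int4", spinquant=True))
    print(needs_hadamard_wrapper(args))  # True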