diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index ad1b9d909..96855b041 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -155,7 +155,10 @@ class Llama: model = convert_to_int4_quantized_model(model, model_args, config) model.load_state_dict(state_dict, strict=True) - if model_args.quantization_args.spinquant: + if ( + model_args.quantization_args is not None + and model_args.quantization_args.spinquant + ): # Add a wrapper for adding hadamard transform for spinquant. # This needs to be done after loading the state dict otherwise an error will be raised while # loading the state dict.