From 8e0a4e2885641c7d518b1592f8185a78d2e8b1d5 Mon Sep 17 00:00:00 2001 From: Sachin Mehta Date: Fri, 25 Oct 2024 12:33:27 -0700 Subject: [PATCH] Added an assertion for model args --- .../providers/impls/meta_reference/inference/generation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index ad1b9d909..96855b041 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -155,7 +155,10 @@ class Llama: model = convert_to_int4_quantized_model(model, model_args, config) model.load_state_dict(state_dict, strict=True) - if model_args.quantization_args.spinquant: + if ( + model_args.quantization_args is not None + and model_args.quantization_args.spinquant + ): # Add a wrapper for adding hadamard transform for spinquant. # This needs to be done after loading the state dict otherwise an error will be raised while # loading the state dict.