From db4f18099f4a50509ae30c39db98230f623a9e35 Mon Sep 17 00:00:00 2001
From: Sachin Mehta
Date: Fri, 25 Oct 2024 12:24:50 -0700
Subject: [PATCH] Changed from config to model_args

---
 .../meta_reference/inference/generation.py |  6 +-----
 .../inference/quantization/loader.py       | 21 +++++++------------
 2 files changed, 9 insertions(+), 18 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 7ac9632fb..ad1b9d909 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -151,15 +151,11 @@ class Llama:
         elif isinstance(config.quantization, Int4QuantizationConfig):
             from .quantization.loader import convert_to_int4_quantized_model
 
-            assert (
-                config.quantization.scheme is not None
-            ), "Please specify a quantization scheme."
-
             model = Transformer(model_args)
             model = convert_to_int4_quantized_model(model, model_args, config)
             model.load_state_dict(state_dict, strict=True)
 
-            if config.quantization.spinquant:
+            if model_args.quantization_args.spinquant:
                 # Add a wrapper for adding hadamard transform for spinquant.
                 # This needs to be done after loading the state dict otherwise an error will be raised while
                 # loading the state dict.
diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
index e07c9fa3b..5ee9c15ee 100644
--- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
+++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
@@ -20,10 +20,6 @@ from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
-from torch import nn, Tensor
-
-from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from llama_stack.apis.inference import QuantizationType
 from llama_stack.apis.inference.inference import Int4QuantizationConfig
@@ -31,6 +27,10 @@ from llama_stack.apis.inference.inference import Int4QuantizationConfig
 from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceQuantizedInferenceConfig,
 )
+from termcolor import cprint
+from torch import nn, Tensor
+
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 
 def swiglu_wrapper(
@@ -309,21 +309,16 @@ def convert_to_int4_quantized_model(
 ) -> Transformer:
     """Convert the model to int4 quantized model."""
 
-    quant_config = config.quantization
-    if not isinstance(quant_config, Int4QuantizationConfig):
-        raise ValueError("Only int4 quantization is supported")
+    if model_args.quantization_args is None:
+        raise ValueError("'quantization_args' cannot be None. Please specify it.")
 
-    if quant_config.type != QuantizationType.int4.value:
-        raise ValueError("Only int4 quantization is supported")
+    quantization_args = model_args.quantization_args
 
-    if quant_config.scheme != "int4_weight_int8_dynamic_activation":
+    if quantization_args.scheme != "int4_weight_int8_dynamic_activation":
         raise NotImplementedError(
             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
         )
 
-    if model_args.quantization_args is None:
-        raise ValueError("'quantization_args' cannot be None. Please specify it.")
-
     group_size = model_args.quantization_args.group_size
     if group_size is None:
         raise ValueError(