forked from phoenix-oss/llama-stack-mirror
Added hadamard transform for spinquant (#326)
* Added hadamard transform for spinquant
* Changed from config to model_args
* Added an assertion for model args
* Use enum.value to check against str
* pre-commit

---------

Co-authored-by: Sachin Mehta <sacmehta@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent 07f9bf723f
commit c05fbf14b3

3 changed files with 109 additions and 14 deletions
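Background for the first bullet above: SpinQuant-style quantization rotates weights and activations with an orthonormal Hadamard matrix so that outliers are spread across channels before low-bit quantization. The snippet below is only a minimal sketch of such a rotation in PyTorch, not the code added by this commit; the function name, shapes, and the example usage are illustrative assumptions.

```python
# Minimal sketch of a normalized Walsh-Hadamard transform (SpinQuant-style
# rotation). Not the implementation from this commit; illustrative only.
import math

import torch


def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Apply an orthonormal Walsh-Hadamard transform along the last dim.

    The last dimension must be a power of two; cost is O(d log d) per row.
    """
    orig_shape = x.shape
    d = orig_shape[-1]
    assert d > 0 and d & (d - 1) == 0, "last dim must be a power of two"
    x = x.reshape(-1, d)
    h = 2
    while h <= d:
        half = h // 2
        x = x.reshape(-1, d // h, h)
        a, b = x[:, :, :half], x[:, :, half:]
        # Butterfly step: (a, b) -> (a + b, a - b)
        x = torch.cat((a + b, a - b), dim=-1)
        h *= 2
    # The 1/sqrt(d) factor makes the transform orthonormal, so the rotation
    # can be folded into adjacent linear weights without changing the output.
    return (x / math.sqrt(d)).reshape(orig_shape)


# Example: rotate the input dimension of a linear weight before quantization.
w = torch.randn(8, 64)
w_rotated = hadamard_transform(w)
```

In SpinQuant-style pipelines the rotation is typically folded into adjacent weight matrices offline, so the network computes the same function and only the quantization error changes. The diff hunks below show the part of the change that switches validation from config to model_args.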
@@ -26,7 +26,6 @@ from torch import nn, Tensor

 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
-from llama_stack.apis.inference.inference import Int4QuantizationConfig

 from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceQuantizedInferenceConfig,
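This hunk drops the now-unused Int4QuantizationConfig import; the other imports shown here are unchanged. For readers unfamiliar with the "int4_weight_int8_dynamic_activation" scheme checked in the next hunk: it combines group-wise 4-bit weights with per-token dynamic 8-bit activations, which is what the imported Int8DynActInt4WeightLinear is named after. The sketch below illustrates that idea in plain PyTorch; it is not torchao's implementation, and the ranges and default group size are assumptions.

```python
# Rough illustration of "int4 weight, int8 dynamic activation" quantization.
# This is NOT torchao's Int8DynActInt4WeightLinear; names, ranges, and the
# group_size default are assumptions made for the example.
import torch


def quantize_weight_int4_groupwise(w: torch.Tensor, group_size: int = 256):
    """Symmetric 4-bit quantization of a (out, in) weight, one scale per group."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0, "in_features must divide by group_size"
    groups = w.reshape(out_features, in_features // group_size, group_size)
    # int4 symmetric range is [-8, 7]; the scale maps each group's max to 7.
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0
    q = torch.clamp(torch.round(groups / scales), -8, 7).to(torch.int8)
    return q.reshape(out_features, in_features), scales.squeeze(-1)


def quantize_activation_int8_per_token(x: torch.Tensor):
    """'Dynamic' int8 quantization: scales are computed per token at runtime."""
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return q, scales
```

Here `group_size` corresponds to the `quantization_args.group_size` value validated at the end of the next hunk.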
@@ -309,21 +308,16 @@ def convert_to_int4_quantized_model(
 ) -> Transformer:
     """Convert the model to int4 quantized model."""
-    quant_config = config.quantization
-    if not isinstance(quant_config, Int4QuantizationConfig):
-        raise ValueError("Only int4 quantization is supported")
+    if model_args.quantization_args is None:
+        raise ValueError("'quantization_args' cannot be None. Please specify it.")

-    if quant_config.type != QuantizationType.int4.value:
-        raise ValueError("Only int4 quantization is supported")
+    quantization_args = model_args.quantization_args

-    if quant_config.scheme != "int4_weight_int8_dynamic_activation":
+    if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
         raise NotImplementedError(
             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
         )

-    if model_args.quantization_args is None:
-        raise ValueError("'quantization_args' cannot be None. Please specify it.")
-
     group_size = model_args.quantization_args.group_size
     if group_size is None:
         raise ValueError(
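The new checks read everything from `model_args.quantization_args` and compare the scheme via `enum.value` against a plain string, matching two of the commit message bullets ("Added an assertion for model args", "Use enum.value to check against str"). The sketch below shows, with hypothetical dataclass and enum definitions that are not the actual llama-stack / llama-models types, how the validated arguments and the check order fit together.

```python
# Hypothetical argument types and the validation order used above.
# Field names, the enum, and the final error message are illustrative
# assumptions, not the real llama-stack / llama-models definitions.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class QuantizationScheme(Enum):
    int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"


@dataclass
class QuantizationArgs:
    scheme: QuantizationScheme
    group_size: Optional[int] = None


@dataclass
class ModelArgs:
    quantization_args: Optional[QuantizationArgs] = None


def validate_int4_args(model_args: ModelArgs) -> QuantizationArgs:
    # 1. The quantization args object itself must be present.
    if model_args.quantization_args is None:
        raise ValueError("'quantization_args' cannot be None. Please specify it.")
    args = model_args.quantization_args
    # 2. Compare the enum's value against the scheme string; comparing the
    #    enum member directly to a str would never match, hence the .value.
    if args.scheme.value != "int4_weight_int8_dynamic_activation":
        raise NotImplementedError(
            "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
        )
    # 3. Group-wise int4 needs an explicit group size.
    if args.group_size is None:
        raise ValueError("'group_size' cannot be None. Please specify it.")
    return args


# Usage example (hypothetical values):
args = ModelArgs(
    quantization_args=QuantizationArgs(
        scheme=QuantizationScheme.int4_weight_int8_dynamic_activation,
        group_size=256,
    )
)
validate_int4_args(args)
```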