Use enum.value to check against str

Ashwin Bharambe 2024-10-25 12:53:32 -07:00
parent 8e0a4e2885
commit d0b6894a41


@@ -20,18 +20,17 @@ from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
-from llama_stack.apis.inference import QuantizationType
-from llama_stack.apis.inference.inference import Int4QuantizationConfig
-from llama_stack.providers.impls.meta_reference.inference.config import (
-    MetaReferenceQuantizedInferenceConfig,
-)
 from termcolor import cprint
 from torch import nn, Tensor
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+from llama_stack.apis.inference import QuantizationType
+from llama_stack.providers.impls.meta_reference.inference.config import (
+    MetaReferenceQuantizedInferenceConfig,
+)
 
 
 def swiglu_wrapper(
     self,
@@ -314,7 +313,7 @@ def convert_to_int4_quantized_model(
     quantization_args = model_args.quantization_args
-    if quantization_args.scheme != "int4_weight_int8_dynamic_activation":
+    if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
         raise NotImplementedError(
             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
         )
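
For context, a minimal sketch of why the .value accessor is needed. The QuantizationScheme class below is a hypothetical stand-in for whatever enum quantization_args.scheme actually carries; the real definition lives in llama_models and may differ:

from enum import Enum

# Hypothetical stand-in for the real scheme enum (assumption, not the
# actual llama_models definition).
class QuantizationScheme(Enum):
    int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"

scheme = QuantizationScheme.int4_weight_int8_dynamic_activation

# A plain Enum member never compares equal to a str, so the old check was
# always True and raised NotImplementedError even for the supported scheme:
print(scheme != "int4_weight_int8_dynamic_activation")        # True

# Comparing the member's .value (a str) against the string behaves as intended:
print(scheme.value != "int4_weight_int8_dynamic_activation")  # False

Had the enum subclassed both str and Enum, the direct comparison would also have worked; comparing .value is the safe route when the scheme is a plain Enum.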