forked from phoenix-oss/llama-stack-mirror
		
	Added hadamard transform for spinquant (#326)
* Added hadamard transform for spinquant * Changed from config to model_args * Added an assertion for model args * Use enum.value to check against str * pre-commit --------- Co-authored-by: Sachin Mehta <sacmehta@fb.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
		
							parent
							
								
									07f9bf723f
								
							
						
					
					
						commit
						c05fbf14b3
					
				
					 3 changed files with 109 additions and 14 deletions
				
			
		|  | @ -26,7 +26,6 @@ from torch import nn, Tensor | |||
| from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear | ||||
| 
 | ||||
| from llama_stack.apis.inference import QuantizationType | ||||
| from llama_stack.apis.inference.inference import Int4QuantizationConfig | ||||
| 
 | ||||
| from llama_stack.providers.impls.meta_reference.inference.config import ( | ||||
|     MetaReferenceQuantizedInferenceConfig, | ||||
|  | @ -309,21 +308,16 @@ def convert_to_int4_quantized_model( | |||
| ) -> Transformer: | ||||
|     """Convert the model to int4 quantized model.""" | ||||
| 
 | ||||
|     quant_config = config.quantization | ||||
|     if not isinstance(quant_config, Int4QuantizationConfig): | ||||
|         raise ValueError("Only int4 quantization is supported") | ||||
|     if model_args.quantization_args is None: | ||||
|         raise ValueError("'quantization_args' cannot be None. Please specify it.") | ||||
| 
 | ||||
|     if quant_config.type != QuantizationType.int4.value: | ||||
|         raise ValueError("Only int4 quantization is supported") | ||||
|     quantization_args = model_args.quantization_args | ||||
| 
 | ||||
|     if quant_config.scheme != "int4_weight_int8_dynamic_activation": | ||||
|     if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation": | ||||
|         raise NotImplementedError( | ||||
|             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported." | ||||
|         ) | ||||
| 
 | ||||
|     if model_args.quantization_args is None: | ||||
|         raise ValueError("'quantization_args' cannot be None. Please specify it.") | ||||
| 
 | ||||
|     group_size = model_args.quantization_args.group_size | ||||
|     if group_size is None: | ||||
|         raise ValueError( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue