Added Hadamard transform for SpinQuant

This commit is contained in:
Sachin Mehta 2024-10-25 11:48:24 -07:00
parent afae4e3d8e
commit 93472042f8
2 changed files with 103 additions and 2 deletions

View file

@@ -35,12 +35,11 @@ from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from .config import (
Fp8QuantizationConfig,
@@ -159,6 +158,16 @@ class Llama:
model = Transformer(model_args)
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
if config.quantization.spinquant:
# Add a wrapper that applies the Hadamard transform for SpinQuant.
# This must be done after loading the state dict; otherwise, an error
# is raised while loading the state dict.
from .quantization.hadamard_utils import (
add_hadamard_transform_for_spinquant,
)
add_hadamard_transform_for_spinquant(model)
else:
raise NotImplementedError(
"Currently int4 and fp8 are the only supported quantization methods."