mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 09:32:36 +00:00
Added hadamard transform for spinquant
This commit is contained in:
parent
afae4e3d8e
commit
93472042f8
2 changed files with 103 additions and 2 deletions
|
|
@ -35,12 +35,11 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
||||
|
|
@ -159,6 +158,16 @@ class Llama:
|
|||
model = Transformer(model_args)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
if config.quantization.spinquant:
|
||||
# Add a wrapper for adding hadamard transform for spinquant.
|
||||
# This needs to be done after loading the state dict otherwise an error will be raised while
|
||||
# loading the state dict.
|
||||
from .quantization.hadamard_utils import (
|
||||
add_hadamard_transform_for_spinquant,
|
||||
)
|
||||
|
||||
add_hadamard_transform_for_spinquant(model)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently int4 and fp8 are the only supported quantization methods."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue