From 36a31fe5dd3947a163d94fce7a68484beb35ded1 Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Wed, 9 Apr 2025 15:00:12 -0700
Subject: [PATCH] fix: on-the-fly int4 quantize parameter (#1920)

Mirror of https://github.com/meta-llama/llama-models/pull/324, with some cleanup.

# What does this PR do?

Removes the `fp8_activation_scale_ub` parameter from the on-the-fly int4
quantization path: `quantize_int4` and `load_int4` no longer take it, and the
`Int4Weights` namedtuple no longer carries an `activation_scale_ub` field. The
int4 call site in `llama4/quantization/loader.py` is updated to the new
signature.

## Test Plan

Build and run the meta-reference GPU distribution with int4 mixed quantization:

```
with-proxy pip install -e .
export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct
export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct
export QUANTIZATION_TYPE=int4_mixed
with-proxy llama stack build --run --template meta-reference-gpu
```
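For reference, a minimal sketch of how the updated API is called after this
change. The tensor shape, device, and variable names below are illustrative
assumptions rather than code from this patch, and running it requires the same
GPU environment as the test plan above.

```python
import torch

from llama_stack.models.llama.quantize_impls import quantize_int4

# Hypothetical bf16 weight standing in for a real model parameter.
w = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# New signature: no fp8_activation_scale_ub argument, just the weight and an
# optional output device.
int4_weights = quantize_int4(w, output_device=torch.device("cuda"))

# The returned Int4Weights namedtuple now carries only weight, scale,
# zero_point, and shape (no activation_scale_ub field).
print(int4_weights.shape, int4_weights.scale.shape)
```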
""" - activation_scale_ub = torch.tensor( - [fp8_activation_scale_ub], - dtype=torch.float, - device=output_device, - ) return Int4Weights( weight=w.to(torch.int8).to(device=output_device), scale=scale.to(device=output_device), zero_point=zero_point.to(device=output_device), shape=w.shape, - activation_scale_ub=activation_scale_ub, )