diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py
index b50432896..f11d83c60 100644
--- a/llama_stack/models/llama/llama4/quantization/loader.py
+++ b/llama_stack/models/llama/llama4/quantization/loader.py
@@ -91,7 +91,7 @@ def convert_to_quantized_model(
             log_status(f"Rank {rank}: Quantizing int4 weights from bf16")

             def apply_quantization(_, weight):
-                return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
+                return quantize_int4(weight, output_device=torch.device("cuda"))

     else:
         fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py
index 6e1d15cf6..a5da01588 100644
--- a/llama_stack/models/llama/quantize_impls.py
+++ b/llama_stack/models/llama/quantize_impls.py
@@ -65,7 +65,7 @@ class Int4Weights(
     Int4ScaledWeights,
     collections.namedtuple(
         "Int4Weights",
-        ["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
+        ["weight", "scale", "zero_point", "shape"],
     ),
 ):
     pass
@@ -184,20 +184,13 @@ def quantize_fp8(
 @torch.inference_mode()
 def quantize_int4(
     w: Tensor,
-    fp8_activation_scale_ub: float,
     output_device: Optional[torch.device] = None,
 ) -> Int4Weights:
     """Quantize [n, k/2] weight tensor.

     Args:
         w (Tensor): [n, k/2] input high precision tensor to quantize.
-        fp8_activation_scale_ub (float): Upper bound for activation max.
     """
-    activation_scale_ub = torch.tensor(
-        [fp8_activation_scale_ub],
-        dtype=torch.float,
-        device=output_device,
-    )
     if w.ndim >= 3:
         wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
         wq = torch.stack([pack_int4(i) for i in wq], dim=0)
@@ -212,7 +205,6 @@ def quantize_int4(
         scale=scale.to(output_device),
         zero_point=zero_point.to(output_device),
         shape=wq.shape,
-        activation_scale_ub=activation_scale_ub,
     )


@@ -247,26 +239,18 @@ def load_int4(
     w: Tensor,
     scale: Tensor,
     zero_point: Tensor,
-    fp8_activation_scale_ub: float,
     output_device: Optional[torch.device] = None,
 ) -> Int4Weights:
     """Load INT4 [n, k/2] weight tensor.

     Args:
         w (Tensor): [n, k/2] input INT4.
-        fp8_activation_scale_ub (float): Upper bound for activation max.
     """
-    activation_scale_ub = torch.tensor(
-        [fp8_activation_scale_ub],
-        dtype=torch.float,
-        device=output_device,
-    )
     return Int4Weights(
         weight=w.to(torch.int8).to(device=output_device),
         scale=scale.to(device=output_device),
         zero_point=zero_point.to(device=output_device),
         shape=w.shape,
-        activation_scale_ub=activation_scale_ub,
     )
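
For reference, a minimal sketch of calling the updated quantize_int4 signature after this change. The weight shape, dtype, and CUDA device below are illustrative assumptions, not part of the patch:

    import torch

    from llama_stack.models.llama.quantize_impls import quantize_int4

    # Hypothetical [n, k] bf16 weight tensor; shape and dtype are assumptions.
    w = torch.randn(128, 256, dtype=torch.bfloat16)

    # fp8_activation_scale_ub is no longer accepted: only the weight and the
    # target device are passed, and the returned Int4Weights namedtuple now
    # carries just weight, scale, zero_point, and shape.
    int4_weights = quantize_int4(w, output_device=torch.device("cuda"))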