diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index b50432896..5d430aa60 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -10,7 +10,7 @@ from typing import Callable, Optional import torch from fairscale.nn.model_parallel.initialize import get_model_parallel_rank -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn import functional as F from ...datatypes import QuantizationMode @@ -91,7 +91,7 @@ def convert_to_quantized_model( log_status(f"Rank {rank}: Quantizing int4 weights from bf16") def apply_quantization(_, weight): - return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda")) + return quantize_int4(weight, output_device=torch.device("cuda")) else: fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")