From c8a0b110c0e37f3d3f474cc6db67697513d7ffb5 Mon Sep 17 00:00:00 2001
From: jiawenliu64
Date: Wed, 9 Apr 2025 13:35:11 -0700
Subject: [PATCH] fix: on-the-fly int4 quantize parameter

Before this PR:

```
[rank1]: TypeError: quantize_int4() got multiple values for argument 'output_device'
```
---
 llama_stack/models/llama/llama4/quantization/loader.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py
index b50432896..5d430aa60 100644
--- a/llama_stack/models/llama/llama4/quantization/loader.py
+++ b/llama_stack/models/llama/llama4/quantization/loader.py
@@ -10,7 +10,7 @@ from typing import Callable, Optional
 
 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import Tensor, nn
+from torch import nn, Tensor
 from torch.nn import functional as F
 
 from ...datatypes import QuantizationMode
@@ -91,7 +91,7 @@ def convert_to_quantized_model(
            log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
 
            def apply_quantization(_, weight):
-                return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
+                return quantize_int4(weight, output_device=torch.device("cuda"))
 
    else:
        fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
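
For context on the traceback above: Python raises "got multiple values for argument" when a call binds the same parameter both positionally and by keyword. The minimal sketch below reproduces the error with a toy stand-in for `quantize_int4`; the stand-in signature and the `1200.0` scale value are illustrative assumptions, not the upstream implementation. The point is that the stray positional `fp8_activation_scale_ub` argument landed in the `output_device` slot, which the explicit keyword then tried to fill again.

```
import torch


def quantize_int4(weight, output_device=None):
    # Stand-in body: the real loader returns quantized weights; here we only
    # need the signature shape to reproduce the argument-binding error.
    return weight, output_device


w = torch.randn(4, 4)

# Before the fix: the positional scale value (hypothetical 1200.0) binds to the
# second parameter slot (output_device), and the keyword then assigns
# output_device a second time, raising the TypeError shown in the PR.
try:
    quantize_int4(w, 1200.0, output_device=torch.device("cuda"))
except TypeError as err:
    print(err)  # quantize_int4() got multiple values for argument 'output_device'

# After the fix: only the keyword is passed, so the binding is unambiguous.
quantize_int4(w, output_device=torch.device("cuda"))
```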