fix: on-the-fly int4 quantize parameter

Before this PR:
```
[rank1]: TypeError: quantize_int4() got multiple values for argument 'output_device'
```
Author: jiawenliu64
Date: 2025-04-09 13:35:11 -07:00
Parent: e2299291c4
Commit: c8a0b110c0


```
@@ -10,7 +10,7 @@ from typing import Callable, Optional
 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import Tensor, nn
+from torch import nn, Tensor
 from torch.nn import functional as F

 from ...datatypes import QuantizationMode
@@ -91,7 +91,7 @@ def convert_to_quantized_model(
         log_status(f"Rank {rank}: Quantizing int4 weights from bf16")

         def apply_quantization(_, weight):
-            return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
+            return quantize_int4(weight, output_device=torch.device("cuda"))

     else:
         fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
```
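
For context, Python raises this TypeError when one parameter receives both a positional and a keyword argument: the error message implies `output_device` is the parameter right after `weight`, so the stray positional `fp8_activation_scale_ub` filled the slot that `output_device=` then tried to fill again. A minimal sketch of the failure mode, using a hypothetical two-parameter stand-in rather than the real `quantize_int4` signature:

```python
def quantize_int4(weight, output_device=None):
    # Hypothetical stand-in: `output_device` is the parameter directly
    # after `weight`, so any second positional argument lands in its slot.
    return weight, output_device


weight, fp8_activation_scale_ub = [0.5, -1.25], 1200.0

# Before the fix: `fp8_activation_scale_ub` binds positionally to
# `output_device`, then the explicit keyword tries to bind it again.
try:
    quantize_int4(weight, fp8_activation_scale_ub, output_device="cuda")
except TypeError as err:
    print(err)  # quantize_int4() got multiple values for argument 'output_device'

# After the fix: the keyword is the only binding, so the call succeeds.
quantize_int4(weight, output_device="cuda")
```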