Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 10:13:05 +00:00)
fix: on-the-fly int4 quantize parameter
Before this PR:

```
[rank1]: TypeError: quantize_int4() got multiple values for argument 'output_device'
```

The old call passed `fp8_activation_scale_ub` positionally, where it was bound to `quantize_int4`'s `output_device` parameter; the explicit `output_device=` keyword then supplied the same parameter a second time, raising the TypeError. The fix drops the extra argument from the int4 call.
parent: e2299291c4
commit: c8a0b110c0
1 changed file with 2 additions and 2 deletions
```diff
@@ -10,7 +10,7 @@ from typing import Callable, Optional

 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import Tensor, nn
+from torch import nn, Tensor
 from torch.nn import functional as F

 from ...datatypes import QuantizationMode

@@ -91,7 +91,7 @@ def convert_to_quantized_model(
         log_status(f"Rank {rank}: Quantizing int4 weights from bf16")

         def apply_quantization(_, weight):
-            return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
+            return quantize_int4(weight, output_device=torch.device("cuda"))

     else:
         fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
```