mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: on-the-fly int4 quantize parameter (#1920)
Mirror to https://github.com/meta-llama/llama-models/pull/324 with some clean up ``` with-proxy pip install -e . export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct export QUANTIZATION_TYPE=int4_mixed with-proxy llama stack build --run --template meta-reference-gpu ``` # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation)
This commit is contained in:
parent
e2299291c4
commit
36a31fe5dd
2 changed files with 2 additions and 18 deletions
|
@ -91,7 +91,7 @@ def convert_to_quantized_model(
|
||||||
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
|
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
|
||||||
|
|
||||||
def apply_quantization(_, weight):
|
def apply_quantization(_, weight):
|
||||||
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
return quantize_int4(weight, output_device=torch.device("cuda"))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
||||||
|
|
|
@ -65,7 +65,7 @@ class Int4Weights(
|
||||||
Int4ScaledWeights,
|
Int4ScaledWeights,
|
||||||
collections.namedtuple(
|
collections.namedtuple(
|
||||||
"Int4Weights",
|
"Int4Weights",
|
||||||
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
|
["weight", "scale", "zero_point", "shape"],
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
@ -184,20 +184,13 @@ def quantize_fp8(
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def quantize_int4(
|
def quantize_int4(
|
||||||
w: Tensor,
|
w: Tensor,
|
||||||
fp8_activation_scale_ub: float,
|
|
||||||
output_device: Optional[torch.device] = None,
|
output_device: Optional[torch.device] = None,
|
||||||
) -> Int4Weights:
|
) -> Int4Weights:
|
||||||
"""Quantize [n, k/2] weight tensor.
|
"""Quantize [n, k/2] weight tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
w (Tensor): [n, k/2] input high precision tensor to quantize.
|
w (Tensor): [n, k/2] input high precision tensor to quantize.
|
||||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
|
||||||
"""
|
"""
|
||||||
activation_scale_ub = torch.tensor(
|
|
||||||
[fp8_activation_scale_ub],
|
|
||||||
dtype=torch.float,
|
|
||||||
device=output_device,
|
|
||||||
)
|
|
||||||
if w.ndim >= 3:
|
if w.ndim >= 3:
|
||||||
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
|
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
|
||||||
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
|
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
|
||||||
|
@ -212,7 +205,6 @@ def quantize_int4(
|
||||||
scale=scale.to(output_device),
|
scale=scale.to(output_device),
|
||||||
zero_point=zero_point.to(output_device),
|
zero_point=zero_point.to(output_device),
|
||||||
shape=wq.shape,
|
shape=wq.shape,
|
||||||
activation_scale_ub=activation_scale_ub,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -247,26 +239,18 @@ def load_int4(
|
||||||
w: Tensor,
|
w: Tensor,
|
||||||
scale: Tensor,
|
scale: Tensor,
|
||||||
zero_point: Tensor,
|
zero_point: Tensor,
|
||||||
fp8_activation_scale_ub: float,
|
|
||||||
output_device: Optional[torch.device] = None,
|
output_device: Optional[torch.device] = None,
|
||||||
) -> Int4Weights:
|
) -> Int4Weights:
|
||||||
"""Load INT4 [n, k/2] weight tensor.
|
"""Load INT4 [n, k/2] weight tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
w (Tensor): [n, k/2] input INT4.
|
w (Tensor): [n, k/2] input INT4.
|
||||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
|
||||||
"""
|
"""
|
||||||
activation_scale_ub = torch.tensor(
|
|
||||||
[fp8_activation_scale_ub],
|
|
||||||
dtype=torch.float,
|
|
||||||
device=output_device,
|
|
||||||
)
|
|
||||||
return Int4Weights(
|
return Int4Weights(
|
||||||
weight=w.to(torch.int8).to(device=output_device),
|
weight=w.to(torch.int8).to(device=output_device),
|
||||||
scale=scale.to(device=output_device),
|
scale=scale.to(device=output_device),
|
||||||
zero_point=zero_point.to(device=output_device),
|
zero_point=zero_point.to(device=output_device),
|
||||||
shape=w.shape,
|
shape=w.shape,
|
||||||
activation_scale_ub=activation_scale_ub,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue