mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
make inference server load checkpoints for fp8 inference
- introduce quantization related args for inference config
- also kill GeneratorArgs
This commit is contained in:
parent 7d2c0b14b8
commit ad62e2e1f3

10 changed files with 249 additions and 155 deletions
@@ -2,7 +2,6 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import collections
from enum import Enum, unique
from typing import Optional, Type

try:
@@ -11,20 +10,12 @@ try:
    print("Using efficient FP8 operators in FBGEMM.")
except (ImportError, ModuleNotFoundError):
    print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
    raise

import torch
from torch import nn, Tensor


@unique
class FfnQuantizeMode(Enum):
    FP8_ROWWISE = "fp8_rowwise"
    NONE = "none"

    def __str__(self) -> str:
        return self.value


class Fp8ScaledWeights:
    # TODO: Ugly trick so torch allows us to replace parameters
    # with our custom Fp8Weights instance. Do this properly.
@@ -84,7 +75,6 @@ def ffn_swiglu(
def quantize_fp8(
    w: Tensor,
    fp8_activation_scale_ub: float,
    mode: Optional[FfnQuantizeMode] = None,
    output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
    """Quantize [n, k] weight tensor.
@@ -92,22 +82,45 @@ def quantize_fp8(
    Args:
        w (Tensor): [n, k] input high precision tensor to quantize.
        fp8_activation_scale_ub (float): Upper bound for activation max.
        mode (FfnQuantizeMode): Quantization mode.
    """
    activation_scale_ub = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device="cuda",
    )
    if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE:  # rowwise
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
        del w
        return Fp8RowwiseWeights(
            weight=wq,
            scale=w_scale,
            shape=wq.shape,
            activation_scale_ub=activation_scale_ub,
        )
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
    del w
    return Fp8RowwiseWeights(
        weight=wq,
        scale=w_scale,
        shape=wq.shape,
        activation_scale_ub=activation_scale_ub,
    )


@torch.inference_mode()
def load_fp8(
    w: Tensor,
    w_scale: Tensor,
    fp8_activation_scale_ub: float,
) -> Fp8RowwiseWeights:
    """Load FP8 [n, k] weight tensor.

    Args:
        w (Tensor): [n, k] input FP8.
        fp8_activation_scale_ub (float): Upper bound for activation max.
    """
    activation_scale_ub = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device="cuda",
    )
    return Fp8RowwiseWeights(
        weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
        scale=w_scale.to(device="cuda"),
        shape=w.shape,
        activation_scale_ub=activation_scale_ub,
    )


def fc_fp8_dynamic(
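The two functions above are the write and read halves of the same format: quantize_fp8 turns a bf16 weight into rowwise FP8 through FBGEMM's quantize_fp8_per_row, while load_fp8 rewraps a weight that was already stored in FP8 together with its per-row scale. A minimal usage sketch, assuming a CUDA device with the FBGEMM FP8 operators installed; the shapes and the 1200.0 bound mirror the loader below, and nothing in this sketch is part of the diff itself:

import torch

from fp8_impls import FfnQuantizeMode, load_fp8, quantize_fp8  # import style as in the test file below

w_bf16 = torch.randn(4096, 14336, dtype=torch.bfloat16, device="cuda")

# bf16 -> FP8 rowwise: one scale per output row plus the shared activation upper bound.
wq = quantize_fp8(w_bf16, 1200.0, mode=FfnQuantizeMode.FP8_ROWWISE)
print(wq.weight.shape, wq.scale.shape)  # expected: [4096, 14336] and a per-row scale vector

# If the FP8 payload and its scales were saved in a checkpoint, load_fp8 rebuilds the same
# wrapper without re-quantizing.
wq_again = load_fp8(wq.weight, wq.scale, 1200.0)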
toolchain/inference/quantization/loader.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os

from typing import Optional

import torch
from models.llama3_1.api.model import Transformer, TransformerBlock

from toolchain.inference.api.config import (
    CheckpointQuantizationFormat,
    InlineImplConfig,
)
from toolchain.inference.api.datatypes import (
    QuantizationType,
)

from termcolor import cprint


def is_fbgemm_available() -> bool:
    try:
        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
        return True
    except (ImportError, ModuleNotFoundError):
        return False


def convert_to_quantized_model(
    model: Transformer,
    config: InlineImplConfig,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
    if config.quantization.type == QuantizationType.bf16.value:
        return model

    elif config.quantization.type != QuantizationType.fp8.value:
        raise ValueError("Only FP8 quantization is supported")

    from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8

    checkpoint = config.checkpoint_config.checkpoint
    # Move weights to GPU with quantization
    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
        cprint("Loading fp8 scales...", "yellow")
        fp8_scales_path = os.path.join(
            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
        )
        assert os.path.isfile(
            fp8_scales_path
        ), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
        fp8_scales = torch.load(fp8_scales_path, weights_only=True)

        for block in model.layers:
            if isinstance(block, TransformerBlock):
                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                    continue
                block.feed_forward.w1.weight = load_fp8(
                    block.feed_forward.w1.weight,
                    fp8_scales[
                        f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
                    ],
                    fp8_activation_scale_ub,
                )
                block.feed_forward.w3.weight = load_fp8(
                    block.feed_forward.w3.weight,
                    fp8_scales[
                        f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
                    ],
                    fp8_activation_scale_ub,
                )
                block.feed_forward.w2.weight = load_fp8(
                    block.feed_forward.w2.weight,
                    fp8_scales[
                        f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
                    ],
                    fp8_activation_scale_ub,
                )
    else:
        cprint("Quantizing fp8 weights from bf16...", "yellow")
        for block in model.layers:
            if isinstance(block, TransformerBlock):
                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                    continue
                block.feed_forward.w1.weight = quantize_fp8(
                    block.feed_forward.w1.weight,
                    fp8_activation_scale_ub,
                    output_device=torch.device("cuda"),
                )
                block.feed_forward.w3.weight = quantize_fp8(
                    block.feed_forward.w3.weight,
                    fp8_activation_scale_ub,
                    output_device=torch.device("cuda"),
                )
                block.feed_forward.w2.weight = quantize_fp8(
                    block.feed_forward.w2.weight,
                    fp8_activation_scale_ub,
                    output_device=torch.device("cuda"),
                )

    for _, parameter in model.named_parameters():
        if not isinstance(parameter, Fp8ScaledWeights):
            parameter.data = parameter.to(device="cuda")
    return model
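Calling the new loader from the inference server looks roughly like the sketch below. InlineImplConfig and the checkpoint object live in toolchain/inference/api and are not shown in this diff, so only the attributes the loader actually reads (quantization.type, checkpoint_config.checkpoint, checkpoint_dir, quantization_format) are relied on; how the config and the Transformer are built is assumed and elided:

# Hedged sketch of a call site; not part of the commit.
from models.llama3_1.api.model import Transformer

from toolchain.inference.api.datatypes import QuantizationType
from toolchain.inference.quantization.loader import (
    convert_to_quantized_model,
    is_fbgemm_available,
)


def maybe_quantize(model: Transformer, config) -> Transformer:
    # config is an InlineImplConfig; for fp8 the loader checks the checkpoint's
    # quantization_format to decide between loading pre-saved fp8 scales (fp8_mixed)
    # and quantizing the bf16 weights on the fly.
    if config.quantization.type == QuantizationType.fp8.value:
        if not is_fbgemm_available():
            raise RuntimeError("FP8 inference requires FBGEMM (see fp8_requirements.txt)")
        return convert_to_quantized_model(model, config, fp8_activation_scale_ub=1200.0)
    return model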
@@ -18,6 +18,12 @@ from fp8.fp8_impls import ffn_swiglu
from torch import nn


@dataclass
class QuantizationArgs:
    fp8_rowwise: bool = False
    convert_from_bf16: bool = False


@dataclass
class ModelArgs:
    dim: int = 4096
@@ -31,6 +37,8 @@ class ModelArgs:
    rope_theta: float = 500000
    use_scaled_rope: bool = False

    quantization: Optional[QuantizationArgs] = None

    max_batch_size: int = 32
    max_seq_len: int = 2048
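For illustration, a ModelArgs carrying the new field could be built like this; the values are arbitrary and everything else keeps the defaults shown above:

args = ModelArgs(
    dim=4096,
    quantization=QuantizationArgs(fp8_rowwise=True, convert_from_bf16=True),
)
# quantization defaults to None, so existing configs keep plain bf16 behavior.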
@@ -5,7 +5,7 @@ import unittest

import torch

from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
from hypothesis import given, settings, strategies as st
from torch import Tensor

@@ -33,70 +33,42 @@ class FP8Tests(unittest.TestCase):
        UB: float,
    ) -> None:
        x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
        w13 = (
            torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        w1 = (
            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        )
        w3 = (
            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        )
        w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1

        x_q = quantize_fp8(x, UB)
        w13_q = quantize_fp8(w13, UB)
        w2_q = quantize_fp8(w2, UB)
        x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
        w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
        w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
        w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)

        def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
        def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
            (B, T, D) = x.shape
            (HD_L_2, D_) = w13.shape
            (HD_L, D_) = w1.shape
            assert D_ == D
            HD_L = HD_L_2 // 2

            y = x.view(B * T, D) @ w13.T
            x1 = y[:, :HD_L]
            x2 = y[:, HD_L:]
            x1 = x.view(B * T, D) @ w1.T
            x2 = x.view(B * T, D) @ w3.T

            z = torch.nn.functional.silu(x1) * x2
            return (z @ w2.T).view(B, T, D).to(torch.bfloat16)

        v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
        v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)

        # Fake quant
        x = x_q.weight.bfloat16() * x_q.scale
        w13 = w13_q.weight.bfloat16() * w13_q.scale
        w2 = w2_q.weight.bfloat16() * w2_q.scale
        x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
        w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
        w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
        w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)

        v_ref = ref_ffn(x, w13, w2)
        v_ref = ref_ffn(x, w1, w3, w2)

        torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)

    @settings(deadline=None)
    @given(
        B_T=st.sampled_from([2048, 4096]),
        D=st.sampled_from([128, 256]),
        HD_L=st.sampled_from([256, 512]),
        UB=st.sampled_from([1000, 10000]),
    )
    def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
        B_T = 4096
        D = 256
        HD_L = 512
        UB = float(UB)
        x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
        wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01

        x_q = quantize_fp8(x, UB)
        wqkv_q = quantize_fp8(wqkv, UB)

        num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")

        y = attn_linear(x, wqkv_q)
        y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)

        # Fake quant
        x = x_q.weight.bfloat16() * x_q.scale
        wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
        y_ref = (x @ wqkv.T).to(torch.bfloat16)

        torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
        torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)


if __name__ == "__main__":
    unittest.main()
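The switch from x_q.scale to x_q.scale.unsqueeze(-1) in the fake-quant reference is the visible consequence of rowwise quantization: the scale is now a vector with one entry per row of an [n, k] tensor, so dequantization has to broadcast it across the k columns. A toy sketch of that broadcast, using an int8 stand-in for the FP8 payload:

import torch

wq = torch.randint(-8, 8, (4, 16), dtype=torch.int8)  # stand-in for the quantized payload
scale = torch.rand(4, dtype=torch.float32)             # one scale per row

# [4, 16] * [4, 1] -> rowwise dequantization; multiplying by the raw [4] vector instead
# would broadcast along the wrong dimension (or fail outright when n != k).
w_dequant = wq.to(torch.bfloat16) * scale.unsqueeze(-1).to(torch.bfloat16)
assert w_dequant.shape == (4, 16)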