make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config (a hedged sketch of what such args could look like follows the commit metadata below)
- also kill GeneratorArgs
Ashwin Bharambe 2024-07-20 21:10:17 -07:00
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions
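
Note on the first bullet: below is a minimal, hypothetical sketch of what quantization-related args on an inference config could look like. All names here (QuantizationMode, InferenceConfig, quantization_mode, fp8_activation_ub, checkpoint_dir) are illustrative assumptions, not the identifiers this commit actually introduces.

# Hypothetical sketch -- names are illustrative, not this commit's actual API.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class QuantizationMode(Enum):
    bf16 = "bf16"   # no quantization, plain bfloat16 checkpoint
    fp8 = "fp8"     # load / run the checkpoint with fp8 weights


@dataclass
class InferenceConfig:
    checkpoint_dir: str
    max_seq_len: int = 2048
    max_batch_size: int = 1
    quantization_mode: QuantizationMode = QuantizationMode.bf16
    # Upper bound used when clamping values before fp8 quantization.
    fp8_activation_ub: Optional[float] = None


cfg = InferenceConfig(
    checkpoint_dir="/path/to/checkpoint",
    quantization_mode=QuantizationMode.fp8,
    fp8_activation_ub=1200.0,
)

Under this assumption, the server would construct one such config at startup and branch on quantization_mode when loading checkpoints, rather than threading separate generator arguments through.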


@@ -5,7 +5,7 @@ import unittest
 import torch
-from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
+from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
 from hypothesis import given, settings, strategies as st
 from torch import Tensor
@@ -33,70 +33,42 @@ class FP8Tests(unittest.TestCase):
         UB: float,
     ) -> None:
         x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        w13 = (
-            torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        w1 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        )
+        w3 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
         )
         w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
-        x_q = quantize_fp8(x, UB)
-        w13_q = quantize_fp8(w13, UB)
-        w2_q = quantize_fp8(w2, UB)
+        x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
-        def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
+        def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
             (B, T, D) = x.shape
-            (HD_L_2, D_) = w13.shape
+            (HD_L, D_) = w1.shape
             assert D_ == D
-            HD_L = HD_L_2 // 2
-            y = x.view(B * T, D) @ w13.T
-            x1 = y[:, :HD_L]
-            x2 = y[:, HD_L:]
+            x1 = x.view(B * T, D) @ w1.T
+            x2 = x.view(B * T, D) @ w3.T
             z = torch.nn.functional.silu(x1) * x2
             return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
-        v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
+        v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
         # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        w13 = w13_q.weight.bfloat16() * w13_q.scale
-        w2 = w2_q.weight.bfloat16() * w2_q.scale
+        x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
+        w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
+        w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
+        w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
-        v_ref = ref_ffn(x, w13, w2)
+        v_ref = ref_ffn(x, w1, w3, w2)
         torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
-    @settings(deadline=None)
-    @given(
-        B_T=st.sampled_from([2048, 4096]),
-        D=st.sampled_from([128, 256]),
-        HD_L=st.sampled_from([256, 512]),
-        UB=st.sampled_from([1000, 10000]),
-    )
-    def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
-        B_T = 4096
-        D = 256
-        HD_L = 512
-        UB = float(UB)
-        x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-        x_q = quantize_fp8(x, UB)
-        wqkv_q = quantize_fp8(wqkv, UB)
-        num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
-        y = attn_linear(x, wqkv_q)
-        y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
-        # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
-        y_ref = (x @ wqkv.T).to(torch.bfloat16)
-        torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
-        torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)
 if __name__ == "__main__":
     unittest.main()
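
A side note on the fake-quant lines above: the change from x_q.scale to x_q.scale.unsqueeze(-1) indicates that quantize_fp8 now returns one scale per row rather than a single tensor-wide scale. Below is a minimal sketch of a row-wise FP8 quantize/dequantize round trip using stock PyTorch float8 tensors; it is an illustration only, not the fp8_impls implementation exercised by this test, and the helper names are invented.

# Illustrative sketch of row-wise fp8 quantize/dequantize; not the fp8_impls kernels.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3fn


def quantize_fp8_rowwise(w: torch.Tensor, ub: float):
    # Clamp outliers to +/- ub, then derive one scale per row from that row's max.
    w = w.float().clamp(min=-ub, max=ub)
    row_amax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = row_amax / FP8_MAX                    # shape: (rows, 1)
    w_fp8 = (w / scale).to(torch.float8_e4m3fn)   # quantized weights
    return w_fp8, scale.squeeze(-1)               # one scale entry per row


def fake_quant(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Mirrors the test's pattern: weight.bfloat16() * scale.unsqueeze(-1)
    return w_fp8.to(torch.bfloat16) * scale.unsqueeze(-1).to(torch.bfloat16)


w = torch.randn(512, 256, dtype=torch.bfloat16) * 0.01
w_q, s = quantize_fp8_rowwise(w, ub=1200.0)
torch.testing.assert_close(fake_quant(w_q, s), w, atol=1.0e-3, rtol=7.0e-2)

With a per-row scale, dequantization must broadcast the scale back over the row dimension, which is exactly what the added unsqueeze(-1) calls in the test accomplish.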