mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
Add toolchain from agentic system here
This commit is contained in:
parent
f6b2b2fb39
commit
95781ec85d
71 changed files with 11899 additions and 0 deletions
102
toolchain/inference/quantization/test_fp8.py
Normal file
102
toolchain/inference/quantization/test_fp8.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
|
||||
from hypothesis import given, settings, strategies as st
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||
"Skip when H100 is not available",
|
||||
)
|
||||
class FP8Tests(unittest.TestCase):
|
||||
@settings(deadline=None)
|
||||
@given(
|
||||
D=st.sampled_from([4096, 8192]),
|
||||
HD_L=st.sampled_from([1280, 2560]),
|
||||
B=st.sampled_from([1, 2]),
|
||||
T=st.sampled_from([2048, 4096]),
|
||||
UB=st.sampled_from([1000, 10000]),
|
||||
)
|
||||
def test_fp8_ffn(
|
||||
self,
|
||||
D: int,
|
||||
HD_L: int,
|
||||
B: int,
|
||||
T: int,
|
||||
UB: float,
|
||||
) -> None:
|
||||
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
w13 = (
|
||||
torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
)
|
||||
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
|
||||
x_q = quantize_fp8(x, UB)
|
||||
w13_q = quantize_fp8(w13, UB)
|
||||
w2_q = quantize_fp8(w2, UB)
|
||||
|
||||
def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
|
||||
(B, T, D) = x.shape
|
||||
(HD_L_2, D_) = w13.shape
|
||||
assert D_ == D
|
||||
HD_L = HD_L_2 // 2
|
||||
|
||||
y = x.view(B * T, D) @ w13.T
|
||||
x1 = y[:, :HD_L]
|
||||
x2 = y[:, HD_L:]
|
||||
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
|
||||
|
||||
v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
|
||||
|
||||
# Fake quant
|
||||
x = x_q.weight.bfloat16() * x_q.scale
|
||||
w13 = w13_q.weight.bfloat16() * w13_q.scale
|
||||
w2 = w2_q.weight.bfloat16() * w2_q.scale
|
||||
|
||||
v_ref = ref_ffn(x, w13, w2)
|
||||
|
||||
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
|
||||
|
||||
@settings(deadline=None)
|
||||
@given(
|
||||
B_T=st.sampled_from([2048, 4096]),
|
||||
D=st.sampled_from([128, 256]),
|
||||
HD_L=st.sampled_from([256, 512]),
|
||||
UB=st.sampled_from([1000, 10000]),
|
||||
)
|
||||
def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
|
||||
B_T = 4096
|
||||
D = 256
|
||||
HD_L = 512
|
||||
UB = float(UB)
|
||||
x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
|
||||
x_q = quantize_fp8(x, UB)
|
||||
wqkv_q = quantize_fp8(wqkv, UB)
|
||||
|
||||
num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
|
||||
|
||||
y = attn_linear(x, wqkv_q)
|
||||
y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
|
||||
|
||||
# Fake quant
|
||||
x = x_q.weight.bfloat16() * x_q.scale
|
||||
wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
|
||||
y_ref = (x @ wqkv.T).to(torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
|
||||
torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue