updating license for toolchain

This commit is contained in:
Ashwin Bharambe 2024-07-22 20:31:42 -07:00
parent 0e2fc9966a
commit 86fff23a9e
74 changed files with 512 additions and 94 deletions

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@ -5,7 +11,7 @@ import unittest
import torch
from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
from hypothesis import given, settings, strategies as st
from torch import Tensor
@ -26,29 +32,25 @@ class FP8Tests(unittest.TestCase):
)
def test_fp8_ffn(
self,
D: int,
D: int, # noqa
HD_L: int,
B: int,
T: int,
UB: float,
) -> None:
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
w1 = (
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
)
w3 = (
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
)
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
(B, T, D) = x.shape
(HD_L, D_) = w1.shape
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
x1 = x.view(B * T, D) @ w1.T