diff --git a/toolchain/configs/ashwin.yaml b/toolchain/configs/ashwin.yaml
index c2f1ca245..80e07df96 100644
--- a/toolchain/configs/ashwin.yaml
+++ b/toolchain/configs/ashwin.yaml
@@ -7,3 +7,5 @@ model_inference_config:
   model_parallel_size: 1
   max_seq_len: 2048
   max_batch_size: 1
+  quantization:
+    type: "fp8"
diff --git a/toolchain/inference/api/config.py b/toolchain/inference/api/config.py
index 2340e2d32..4a6c5145f 100644
--- a/toolchain/inference/api/config.py
+++ b/toolchain/inference/api/config.py
@@ -7,14 +7,7 @@
 from hydra.core.config_store import ConfigStore
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
-
-@dataclass
-class GeneratorArgs:
-    ckpt_dir: str
-    tokenizer_path: str
-    model_parallel_size: Optional[int] = None
-    max_seq_len: int = 2048
-    max_batch_size: int = 4
+from .datatypes import QuantizationConfig
 
 
 class ImplType(Enum):
@@ -27,6 +20,17 @@ class CheckpointType(Enum):
     huggingface = "huggingface"
 
 
+# This enum represents the format in which weights are specified
+# This does not necessarily always equal what quantization is desired
+# at runtime since there can be on-the-fly conversions done
+class CheckpointQuantizationFormat(Enum):
+    # default format
+    bf16 = "bf16"
+
+    # used for enabling fp8_rowwise inference, some weights are bf16
+    fp8_mixed = "fp8_mixed"
+
+
 class PytorchCheckpoint(BaseModel):
     checkpoint_type: Literal[CheckpointType.pytorch.value] = (
         CheckpointType.pytorch.value
@@ -34,6 +38,9 @@ class PytorchCheckpoint(BaseModel):
     checkpoint_dir: str
     tokenizer_path: str
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class HuggingFaceCheckpoint(BaseModel):
@@ -42,6 +49,9 @@ class HuggingFaceCheckpoint(BaseModel):
     )
     repo_id: str  # or model_name ?
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class ModelCheckpointConfig(BaseModel):
@@ -51,10 +61,11 @@ class ModelCheckpointConfig(BaseModel):
     ]
 
 
-# NOTE: this same config will be used when instantiating an inference server naturally
 class InlineImplConfig(BaseModel):
     impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
     checkpoint_config: ModelCheckpointConfig
+    quantization: Optional[QuantizationConfig] = None
+    torch_seed: Optional[int] = None
     max_seq_len: int
     max_batch_size: int = 1
@@ -86,6 +97,7 @@ class InlineImplHydraConfig:
     model_parallel_size: int
     max_seq_len: int
     max_batch_size: int = 1
+    quantization: Optional[QuantizationConfig] = None
     # TODO: huggingface checkpoint required args
 
     def convert_to_inline_impl_config(self):
@@ -99,6 +111,7 @@ class InlineImplHydraConfig:
                     model_parallel_size=self.model_parallel_size,
                 )
             ),
+            quantization=self.quantization,
            max_seq_len=self.max_seq_len,
            max_batch_size=self.max_batch_size,
        )
diff --git a/toolchain/inference/api/datatypes.py b/toolchain/inference/api/datatypes.py
index cd7c8a432..90b9dfe73 100644
--- a/toolchain/inference/api/datatypes.py
+++ b/toolchain/inference/api/datatypes.py
@@ -21,19 +21,19 @@ class QuantizationType(Enum):
 
 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
+    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
 
 
 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.bf16.value] = (
+    type: Literal[QuantizationType.bf16.value] = (
         QuantizationType.bf16.value
     )
 
 
 QuantizationConfig = Annotated[
     Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
-    Field(discriminator="quantization_type"),
+    Field(discriminator="type"),
 ]
diff --git a/toolchain/inference/generation.py b/toolchain/inference/generation.py
index ecca76572..f714760ec 100644
--- a/toolchain/inference/generation.py
+++ b/toolchain/inference/generation.py
@@ -7,7 +7,7 @@ import sys
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional, TypedDict
+from typing import Generator, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -23,6 +23,9 @@ from models.llama3_1.api.model import Transformer
 from models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint
 
+from .api.config import CheckpointType, InlineImplConfig
+from .api.datatypes import QuantizationType
+
 
 @dataclass
 class TokenResult:
@@ -31,69 +34,52 @@ class TokenResult:
     logprobs: Optional[List[float]] = None
 
 
-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 class Llama:
     @staticmethod
-    def build(
-        ckpt_dir: str,
-        tokenizer_path: str,
-        max_seq_len: int,
-        max_batch_size: int,
-        model_parallel_size: Optional[int] = None,
-        seed: int = 1,
-    ) -> "Llama":
+    def build(config: InlineImplConfig):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
 
-        Args:
-            ckpt_dir (str): Path to the directory containing checkpoint files.
-            tokenizer_path (str): Path to the tokenizer file.
-            max_seq_len (int): Maximum sequence length for input text.
-            max_batch_size (int): Maximum batch size for inference.
-            model_parallel_size (Optional[int], optional): Number of model parallel processes.
-                If not provided, it's determined from the environment. Defaults to None.
-
-        Returns:
-            Llama: An instance of the Llama class with the loaded model and tokenizer.
-
-        Raises:
-            AssertionError: If there are no checkpoint files in the specified directory,
-                or if the model parallel size does not match the number of checkpoint files.
-
         Note:
             This method initializes the distributed process group, sets the device to CUDA,
             and loads the pre-trained model and tokenizer.
         """
+        checkpoint = config.checkpoint_config.checkpoint
+        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
+            raise NotImplementedError("HuggingFace checkpoints not supported yet")
+
+        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+            from .quantization.loader import is_fbgemm_available
+
+            if not is_fbgemm_available():
+                raise ImportError("fbgemm-gpu is required for FP8 quantization")
+
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
+
+        model_parallel_size = checkpoint.model_parallel_size
         if not model_parallel_is_initialized():
-            if model_parallel_size is None:
-                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)
 
         # seed must be the same in all processes
-        torch.manual_seed(seed)
+        if config.torch_seed is not None:
+            torch.manual_seed(config.torch_seed)
 
         if local_rank > 0:
             sys.stdout = open(os.devnull, "w")
 
         start_time = time.time()
+        ckpt_dir = checkpoint.checkpoint_dir
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         assert model_parallel_size == len(
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
@@ -103,22 +89,34 @@ class Llama:
             params = params["model"]
 
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=config.max_seq_len,
+            max_batch_size=config.max_batch_size,
             **params,
         )
-        tokenizer = Tokenizer(model_path=tokenizer_path)
+        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
+
+        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
+        torch.set_default_tensor_type(torch.BFloat16Tensor)
+
+        model = Transformer(model_args)
+        model.load_state_dict(state_dict, strict=False)
+
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
-        model = Transformer(model_args)
-        model.load_state_dict(checkpoint, strict=False)
+        if config.quantization:
+            from .quantization.loader import convert_to_quantized_model
+
+            model = convert_to_quantized_model(model, config)
+        else:
+            model = model.to("cuda")
+
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
 
         return Llama(model, tokenizer, model_args)
diff --git a/toolchain/inference/inference.py b/toolchain/inference/inference.py
index 5a117eb09..5ec1c897d 100644
--- a/toolchain/inference/inference.py
+++ b/toolchain/inference/inference.py
@@ -2,10 +2,15 @@ from typing import AsyncGenerator
 
 from models.llama3_1.api.datatypes import StopReason
 
-from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
+from .api.config import (
+    CheckpointQuantizationFormat,
+    CheckpointType,
+    InlineImplConfig,
+)
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
+    QuantizationConfig,
     ToolCallDelta,
     ToolCallParseStatus,
 )
@@ -18,33 +23,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator
 
 
-def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
-    if (
-        config.checkpoint_config.checkpoint.checkpoint_type
-        == CheckpointType.pytorch.value
-    ):
-        pt_checkpoint = config.checkpoint_config.checkpoint
-        return GeneratorArgs(
-            ckpt_dir=pt_checkpoint.checkpoint_dir,
-            tokenizer_path=pt_checkpoint.tokenizer_path,
-            model_parallel_size=pt_checkpoint.model_parallel_size,
-            max_seq_len=config.max_seq_len,
-            max_batch_size=config.max_batch_size,
-        )
-    else:
-        raise NotImplementedError("HF Checkpoint not supported yet")
-
-
 class ModelInferenceImpl(ModelInference):
 
     def __init__(self, config: InlineImplConfig) -> None:
         self.config = config
 
     async def initialize(self) -> None:
-        generator_args = generator_args_from_config(self.config)
-        self.generator = LlamaModelParallelGenerator(
-            args=generator_args,
-        )
+        self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
 
     async def shutdown(self) -> None:
diff --git a/toolchain/inference/model_parallel.py b/toolchain/inference/model_parallel.py
index 2a7fcf781..2ffbe2fb0 100644
--- a/toolchain/inference/model_parallel.py
+++ b/toolchain/inference/model_parallel.py
@@ -6,7 +6,7 @@ from models.llama3_1.api.chat_format import ChatFormat
 from models.llama3_1.api.datatypes import Message
 from models.llama3_1.api.tokenizer import Tokenizer
 
-from .api.config import GeneratorArgs
+from .api.config import InlineImplConfig
 from .generation import Llama
 from .parallel_utils import ModelParallelProcessGroup
 
@@ -35,13 +35,8 @@ class ModelRunner:
         )
 
 
-def init_model_cb(args: GeneratorArgs):
-    llama = Llama.build(
-        args.ckpt_dir,
-        args.tokenizer_path,
-        args.max_seq_len,
-        args.max_batch_size,
-    )
+def init_model_cb(config: InlineImplConfig):
+    llama = Llama.build(config)
     return ModelRunner(llama)
 
 
@@ -56,12 +51,13 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """
 
-    def __init__(self, args: GeneratorArgs):
-        self.args = args
+    def __init__(self, config: InlineImplConfig):
+        self.config = config
 
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path))
+        checkpoint = self.config.checkpoint_config.checkpoint
+        self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
 
     def start(self):
         self.__enter__()
@@ -70,9 +66,10 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)
 
     def __enter__(self):
+        checkpoint = self.config.checkpoint_config.checkpoint
         self.group = ModelParallelProcessGroup(
-            self.args.model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.args),
+            checkpoint.model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
         return self
diff --git a/toolchain/inference/quantization/fp8_impls.py b/toolchain/inference/quantization/fp8_impls.py
index 095039b24..9cac8bea0 100644
--- a/toolchain/inference/quantization/fp8_impls.py
+++ b/toolchain/inference/quantization/fp8_impls.py
@@ -2,7 +2,6 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import collections
-from enum import Enum, unique
 from typing import Optional, Type
 
 try:
@@ -11,20 +10,12 @@ try:
     print("Using efficient FP8 operators in FBGEMM.")
 except (ImportError, ModuleNotFoundError):
     print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
+    raise
 
 import torch
 from torch import nn, Tensor
 
 
-@unique
-class FfnQuantizeMode(Enum):
-    FP8_ROWWISE = "fp8_rowwise"
-    NONE = "none"
-
-    def __str__(self) -> str:
-        return self.value
-
-
 class Fp8ScaledWeights:
     # TODO: Ugly trick so torch allows us to replace parameters
     # with our custom Fp8Weights instance. Do this properly.
@@ -84,7 +75,6 @@ def ffn_swiglu(
 def quantize_fp8(
     w: Tensor,
     fp8_activation_scale_ub: float,
-    mode: Optional[FfnQuantizeMode] = None,
     output_device: Optional[torch.device] = None,
 ) -> Fp8RowwiseWeights:
     """Quantize [n, k] weight tensor.
@@ -92,22 +82,45 @@ def quantize_fp8(
     Args:
         w (Tensor): [n, k] input high precision tensor to quantize.
         fp8_activation_scale_ub (float): Upper bound for activation max.
-        mode (FfnQuantizeMode): Quantization mode.
     """
     activation_scale_ub = torch.tensor(
         [fp8_activation_scale_ub],
         dtype=torch.float,
         device="cuda",
     )
-    if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE:  # rowwise
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-        del w
-        return Fp8RowwiseWeights(
-            weight=wq,
-            scale=w_scale,
-            shape=wq.shape,
-            activation_scale_ub=activation_scale_ub,
-        )
+    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+    del w
+    return Fp8RowwiseWeights(
+        weight=wq,
+        scale=w_scale,
+        shape=wq.shape,
+        activation_scale_ub=activation_scale_ub,
+    )
+
+
+@torch.inference_mode()
+def load_fp8(
+    w: Tensor,
+    w_scale: Tensor,
+    fp8_activation_scale_ub: float,
+) -> Fp8RowwiseWeights:
+    """Load FP8 [n, k] weight tensor.
+
+    Args:
+        w (Tensor): [n, k] input FP8.
+        fp8_activation_scale_ub (float): Upper bound for activation max.
+    """
+    activation_scale_ub = torch.tensor(
+        [fp8_activation_scale_ub],
+        dtype=torch.float,
+        device="cuda",
+    )
+    return Fp8RowwiseWeights(
+        weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
+        scale=w_scale.to(device="cuda"),
+        shape=w.shape,
+        activation_scale_ub=activation_scale_ub,
+    )
 
 
 def fc_fp8_dynamic(
diff --git a/toolchain/inference/quantization/loader.py b/toolchain/inference/quantization/loader.py
new file mode 100644
index 000000000..fde77685f
--- /dev/null
+++ b/toolchain/inference/quantization/loader.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+
+import os
+
+from typing import Optional
+
+import torch
+
+from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
+from models.llama3_1.api.model import Transformer, TransformerBlock
+
+from toolchain.inference.api.config import (
+    CheckpointQuantizationFormat,
+    InlineImplConfig,
+)
+from toolchain.inference.api.datatypes import (
+    QuantizationType,
+)
+
+from termcolor import cprint
+
+
+def is_fbgemm_available() -> bool:
+    try:
+        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+
+        return True
+    except (ImportError, ModuleNotFoundError):
+        return False
+
+
+def convert_to_quantized_model(
+    model: Transformer,
+    config: InlineImplConfig,
+    fp8_activation_scale_ub: Optional[float] = 1200.0,
+) -> Transformer:
+    if config.quantization.type == QuantizationType.bf16.value:
+        return model
+
+    elif config.quantization.type != QuantizationType.fp8.value:
+        raise ValueError("Only FP8 quantization is supported")
+
+    from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
+
+    checkpoint = config.checkpoint_config.checkpoint
+    # Move weights to GPU with quantization
+    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
+        cprint("Loading fp8 scales...", "yellow")
+        fp8_scales_path = os.path.join(
+            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
+        )
+        assert os.path.isfile(
+            fp8_scales_path
+        ), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+        fp8_scales = torch.load(fp8_scales_path, weights_only=True)
+
+        for block in model.layers:
+            if isinstance(block, TransformerBlock):
+                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
+                    continue
+                block.feed_forward.w1.weight = load_fp8(
+                    block.feed_forward.w1.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+                block.feed_forward.w3.weight = load_fp8(
+                    block.feed_forward.w3.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+                block.feed_forward.w2.weight = load_fp8(
+                    block.feed_forward.w2.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+    else:
+        cprint("Quantizing fp8 weights from bf16...", "yellow")
+        for block in model.layers:
+            if isinstance(block, TransformerBlock):
+                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
+                    continue
+                block.feed_forward.w1.weight = quantize_fp8(
+                    block.feed_forward.w1.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+                block.feed_forward.w3.weight = quantize_fp8(
+                    block.feed_forward.w3.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+                block.feed_forward.w2.weight = quantize_fp8(
+                    block.feed_forward.w2.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+
+    for _, parameter in model.named_parameters():
+        if not isinstance(parameter, Fp8ScaledWeights):
+            parameter.data = parameter.to(device="cuda")
+    return model
diff --git a/toolchain/inference/quantization/model.py b/toolchain/inference/quantization/model.py
index ce806e697..44500d494 100644
--- a/toolchain/inference/quantization/model.py
+++ b/toolchain/inference/quantization/model.py
@@ -18,6 +18,12 @@ from fp8.fp8_impls import ffn_swiglu
 from torch import nn
 
 
+@dataclass
+class QuantizationArgs:
+    fp8_rowwise: bool = False
+    convert_from_bf16: bool = False
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096
@@ -31,6 +37,8 @@ class ModelArgs:
     rope_theta: float = 500000
     use_scaled_rope: bool = False
 
+    quantization: Optional[QuantizationArgs] = None
+
     max_batch_size: int = 32
     max_seq_len: int = 2048
 
diff --git a/toolchain/inference/quantization/test_fp8.py b/toolchain/inference/quantization/test_fp8.py
index 3e6f75213..27b95f65c 100644
--- a/toolchain/inference/quantization/test_fp8.py
+++ b/toolchain/inference/quantization/test_fp8.py
@@ -5,7 +5,7 @@ import unittest
 
 import torch
 
-from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
+from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8
 from hypothesis import given, settings, strategies as st
 from torch import Tensor
 
@@ -33,70 +33,42 @@ class FP8Tests(unittest.TestCase):
         UB: float,
     ) -> None:
         x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        w13 = (
-            torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        w1 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        )
+        w3 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
         )
         w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
 
         x_q = quantize_fp8(x, UB)
-        w13_q = quantize_fp8(w13, UB)
-        w2_q = quantize_fp8(w2, UB)
+        w1_q = quantize_fp8(w1, UB)
+        w3_q = quantize_fp8(w3, UB)
+        w2_q = quantize_fp8(w2, UB)
 
-        def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
+        def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
             (B, T, D) = x.shape
-            (HD_L_2, D_) = w13.shape
+            (HD_L, D_) = w1.shape
             assert D_ == D
-            HD_L = HD_L_2 // 2
 
-            y = x.view(B * T, D) @ w13.T
-            x1 = y[:, :HD_L]
-            x2 = y[:, HD_L:]
+            x1 = x.view(B * T, D) @ w1.T
+            x2 = x.view(B * T, D) @ w3.T
 
             z = torch.nn.functional.silu(x1) * x2
             return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
 
-        v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
+        v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
 
         # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        w13 = w13_q.weight.bfloat16() * w13_q.scale
-        w2 = w2_q.weight.bfloat16() * w2_q.scale
+        x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
+        w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
+        w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
+        w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
 
-        v_ref = ref_ffn(x, w13, w2)
+        v_ref = ref_ffn(x, w1, w3, w2)
 
         torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
 
-    @settings(deadline=None)
-    @given(
-        B_T=st.sampled_from([2048, 4096]),
-        D=st.sampled_from([128, 256]),
-        HD_L=st.sampled_from([256, 512]),
-        UB=st.sampled_from([1000, 10000]),
-    )
-    def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
-        B_T = 4096
-        D = 256
-        HD_L = 512
-        UB = float(UB)
-
-        x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-
-        x_q = quantize_fp8(x, UB)
-        wqkv_q = quantize_fp8(wqkv, UB)
-
-        num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
-
-        y = attn_linear(x, wqkv_q)
-        y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
-
-        # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
-
-        y_ref = (x @ wqkv.T).to(torch.bfloat16)
-
-        torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
-        torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)
-
 
 if __name__ == "__main__":
     unittest.main()
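
A minimal usage sketch of the plumbing added above, assuming the checkpoint and tokenizer paths are placeholders and that a CUDA host with the usual torch.distributed environment variables is available; it mirrors in Python what toolchain/configs/ashwin.yaml now expresses in YAML:

# Sketch only, not part of the patch. Paths and model_parallel_size are
# illustrative placeholders.
from toolchain.inference.api.config import (
    InlineImplConfig,
    ModelCheckpointConfig,
    PytorchCheckpoint,
)
from toolchain.inference.api.datatypes import Fp8QuantizationConfig
from toolchain.inference.generation import Llama

config = InlineImplConfig(
    checkpoint_config=ModelCheckpointConfig(
        checkpoint=PytorchCheckpoint(
            checkpoint_dir="/path/to/llama-checkpoint",
            tokenizer_path="/path/to/tokenizer.model",
            model_parallel_size=1,
            # quantization_format defaults to bf16 here, so Llama.build() will
            # quantize the FFN weights to fp8 on the fly; fp8_mixed checkpoints
            # instead load the pre-computed fp8 scales.
        )
    ),
    quantization=Fp8QuantizationConfig(),  # discriminated by the new "type" field
    max_seq_len=2048,
    max_batch_size=1,
)

llama = Llama.build(config)  # the fp8 path requires fbgemm-gpu to be installed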