make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config
- also kill GeneratorArgs
Ashwin Bharambe 2024-07-20 21:10:17 -07:00
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions

View file

@@ -7,3 +7,5 @@ model_inference_config:
   model_parallel_size: 1
   max_seq_len: 2048
   max_batch_size: 1
+  quantization:
+    type: "fp8"

View file

@@ -7,14 +7,7 @@ from hydra.core.config_store import ConfigStore
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
+from .datatypes import QuantizationConfig

-
-@dataclass
-class GeneratorArgs:
-    ckpt_dir: str
-    tokenizer_path: str
-    model_parallel_size: Optional[int] = None
-    max_seq_len: int = 2048
-    max_batch_size: int = 4

 class ImplType(Enum):
@@ -27,6 +20,17 @@ class CheckpointType(Enum):
     huggingface = "huggingface"


+# This enum represents the format in which weights are specified
+# This does not necessarily always equal what quantization is desired
+# at runtime since there can be on-the-fly conversions done
+class CheckpointQuantizationFormat(Enum):
+    # default format
+    bf16 = "bf16"
+
+    # used for enabling fp8_rowwise inference, some weights are bf16
+    fp8_mixed = "fp8_mixed"
+
+
 class PytorchCheckpoint(BaseModel):
     checkpoint_type: Literal[CheckpointType.pytorch.value] = (
         CheckpointType.pytorch.value
@@ -34,6 +38,9 @@ class PytorchCheckpoint(BaseModel):
     checkpoint_dir: str
     tokenizer_path: str
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )


 class HuggingFaceCheckpoint(BaseModel):
@@ -42,6 +49,9 @@ class HuggingFaceCheckpoint(BaseModel):
     )
     repo_id: str  # or model_name ?
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )


 class ModelCheckpointConfig(BaseModel):
@@ -51,10 +61,11 @@ class ModelCheckpointConfig(BaseModel):
     ]


-# NOTE: this same config will be used when instantiating an inference server naturally
 class InlineImplConfig(BaseModel):
     impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
     checkpoint_config: ModelCheckpointConfig
+    quantization: Optional[QuantizationConfig] = None
+    torch_seed: Optional[int] = None
     max_seq_len: int
     max_batch_size: int = 1
@@ -86,6 +97,7 @@ class InlineImplHydraConfig:
     model_parallel_size: int
     max_seq_len: int
     max_batch_size: int = 1
+    quantization: Optional[QuantizationConfig] = None
     # TODO: huggingface checkpoint required args

     def convert_to_inline_impl_config(self):
@@ -99,6 +111,7 @@ class InlineImplHydraConfig:
                     model_parallel_size=self.model_parallel_size,
                 )
             ),
+            quantization=self.quantization,
            max_seq_len=self.max_seq_len,
            max_batch_size=self.max_batch_size,
        )
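For orientation, a hedged sketch of what the new config surface looks like when built directly in Python. The field names are taken from the diff above; the file-system paths are placeholders, and the import modules (`toolchain.inference.api.config` / `...datatypes`) are inferred from imports elsewhere in this commit rather than shown here.

```python
# Illustrative only: construct the new InlineImplConfig with fp8 enabled.
from toolchain.inference.api.config import (  # module path assumed from this commit's imports
    CheckpointQuantizationFormat,
    InlineImplConfig,
    ModelCheckpointConfig,
    PytorchCheckpoint,
)
from toolchain.inference.api.datatypes import Fp8QuantizationConfig  # assumed location

config = InlineImplConfig(
    checkpoint_config=ModelCheckpointConfig(
        checkpoint=PytorchCheckpoint(
            checkpoint_dir="/path/to/checkpoint",       # placeholder path
            tokenizer_path="/path/to/tokenizer.model",  # placeholder path
            model_parallel_size=1,
            quantization_format=CheckpointQuantizationFormat.fp8_mixed,
        )
    ),
    quantization=Fp8QuantizationConfig(),  # selected at runtime via its "type" discriminator
    torch_seed=1,
    max_seq_len=2048,
    max_batch_size=1,
)
```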

View file

@@ -21,19 +21,19 @@ class QuantizationType(Enum):

 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
+    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value


 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.bf16.value] = (
+    type: Literal[QuantizationType.bf16.value] = (
         QuantizationType.bf16.value
     )


 QuantizationConfig = Annotated[
     Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
-    Field(discriminator="quantization_type"),
+    Field(discriminator="type"),
 ]
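The rename from `quantization_type` to `type` is what lets the YAML in the first file (`type: "fp8"`) pick the right config class. A minimal sketch of how the pydantic discriminated union resolves it; the wrapper model below is purely illustrative and not part of the diff.

```python
from typing import Optional

from pydantic import BaseModel

# Assumes QuantizationConfig / Fp8QuantizationConfig are importable as defined above.
from toolchain.inference.api.datatypes import Fp8QuantizationConfig, QuantizationConfig


class _Demo(BaseModel):  # hypothetical wrapper, only for illustration
    quantization: Optional[QuantizationConfig] = None


cfg = _Demo.parse_obj({"quantization": {"type": "fp8"}})  # on pydantic v2: _Demo.model_validate(...)
assert isinstance(cfg.quantization, Fp8QuantizationConfig)
```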

View file

@@ -7,7 +7,7 @@ import sys
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional, TypedDict
+from typing import Generator, List, Optional

 import torch
 import torch.nn.functional as F
@@ -23,6 +23,9 @@ from models.llama3_1.api.model import Transformer
 from models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint

+from .api.config import CheckpointType, InlineImplConfig
+from .api.datatypes import QuantizationType
+

 @dataclass
 class TokenResult:
@@ -31,69 +34,52 @@ class TokenResult:
     logprobs: Optional[List[float]] = None


-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 class Llama:
     @staticmethod
-    def build(
-        ckpt_dir: str,
-        tokenizer_path: str,
-        max_seq_len: int,
-        max_batch_size: int,
-        model_parallel_size: Optional[int] = None,
-        seed: int = 1,
-    ) -> "Llama":
+    def build(config: InlineImplConfig):
         """
         Build a Llama instance by initializing and loading a model checkpoint.

-        Args:
-            ckpt_dir (str): Path to the directory containing checkpoint files.
-            tokenizer_path (str): Path to the tokenizer file.
-            max_seq_len (int): Maximum sequence length for input text.
-            max_batch_size (int): Maximum batch size for inference.
-            model_parallel_size (Optional[int], optional): Number of model parallel processes.
-                If not provided, it's determined from the environment. Defaults to None.
-
-        Returns:
-            Llama: An instance of the Llama class with the loaded model and tokenizer.
-
-        Raises:
-            AssertionError: If there are no checkpoint files in the specified directory,
-                or if the model parallel size does not match the number of checkpoint files.
-
         Note:
             This method initializes the distributed process group, sets the device to CUDA,
            and loads the pre-trained model and tokenizer.
         """
+        checkpoint = config.checkpoint_config.checkpoint
+        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
+            raise NotImplementedError("HuggingFace checkpoints not supported yet")
+
+        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+            from .quantization.loader import is_fbgemm_available
+
+            if not is_fbgemm_available():
+                raise ImportError("fbgemm-gpu is required for FP8 quantization")
+
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

+        model_parallel_size = checkpoint.model_parallel_size
         if not model_parallel_is_initialized():
-            if model_parallel_size is None:
-                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
             initialize_model_parallel(model_parallel_size)

         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)

         # seed must be the same in all processes
-        torch.manual_seed(seed)
+        if config.torch_seed is not None:
+            torch.manual_seed(config.torch_seed)

         if local_rank > 0:
             sys.stdout = open(os.devnull, "w")

         start_time = time.time()
+        ckpt_dir = checkpoint.checkpoint_dir
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         assert model_parallel_size == len(
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
@@ -103,22 +89,34 @@ class Llama:
             params = params["model"]

         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=config.max_seq_len,
+            max_batch_size=config.max_batch_size,
             **params,
         )
-        tokenizer = Tokenizer(model_path=tokenizer_path)
+        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)

         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

+        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
+        torch.set_default_tensor_type(torch.BFloat16Tensor)
+
+        model = Transformer(model_args)
+        model.load_state_dict(state_dict, strict=False)
+
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)

-        model = Transformer(model_args)
-        model.load_state_dict(checkpoint, strict=False)
+        if config.quantization:
+            from .quantization.loader import convert_to_quantized_model
+
+            model = convert_to_quantized_model(model, config)
+        else:
+            model = model.to("cuda")

         print(f"Loaded in {time.time() - start_time:.2f} seconds")
         return Llama(model, tokenizer, model_args)
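With this change the entry point collapses to a single call: distributed init, checkpoint loading, and optional fp8 conversion all happen inside `build()`. A hedged sketch, reusing the `config` object from the earlier example; the module path is assumed from the relative imports in this file.

```python
# Must run under torchrun (or with NCCL / LOCAL_RANK environment set up),
# since build() initializes the process group itself.
from toolchain.inference.generation import Llama  # module path assumed

llama = Llama.build(config)  # `config` as constructed in the earlier sketch
```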

View file

@@ -2,10 +2,15 @@ from typing import AsyncGenerator

 from models.llama3_1.api.datatypes import StopReason

-from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
+from .api.config import (
+    CheckpointQuantizationFormat,
+    CheckpointType,
+    InlineImplConfig,
+)
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
+    QuantizationConfig,
     ToolCallDelta,
     ToolCallParseStatus,
 )
@@ -18,33 +23,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator


-def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
-    if (
-        config.checkpoint_config.checkpoint.checkpoint_type
-        == CheckpointType.pytorch.value
-    ):
-        pt_checkpoint = config.checkpoint_config.checkpoint
-        return GeneratorArgs(
-            ckpt_dir=pt_checkpoint.checkpoint_dir,
-            tokenizer_path=pt_checkpoint.tokenizer_path,
-            model_parallel_size=pt_checkpoint.model_parallel_size,
-            max_seq_len=config.max_seq_len,
-            max_batch_size=config.max_batch_size,
-        )
-    else:
-        raise NotImplementedError("HF Checkpoint not supported yet")
-
-
 class ModelInferenceImpl(ModelInference):

     def __init__(self, config: InlineImplConfig) -> None:
         self.config = config

     async def initialize(self) -> None:
-        generator_args = generator_args_from_config(self.config)
-        self.generator = LlamaModelParallelGenerator(
-            args=generator_args,
-        )
+        self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()

     async def shutdown(self) -> None:

View file

@@ -6,7 +6,7 @@ from models.llama3_1.api.chat_format import ChatFormat
 from models.llama3_1.api.datatypes import Message
 from models.llama3_1.api.tokenizer import Tokenizer

-from .api.config import GeneratorArgs
+from .api.config import InlineImplConfig
 from .generation import Llama
 from .parallel_utils import ModelParallelProcessGroup
@@ -35,13 +35,8 @@ class ModelRunner:
         )


-def init_model_cb(args: GeneratorArgs):
-    llama = Llama.build(
-        args.ckpt_dir,
-        args.tokenizer_path,
-        args.max_seq_len,
-        args.max_batch_size,
-    )
+def init_model_cb(config: InlineImplConfig):
+    llama = Llama.build(config)
     return ModelRunner(llama)
@@ -56,12 +51,13 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, args: GeneratorArgs):
-        self.args = args
+    def __init__(self, config: InlineImplConfig):
+        self.config = config

         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path))
+        checkpoint = self.config.checkpoint_config.checkpoint
+        self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))

     def start(self):
         self.__enter__()
@@ -70,9 +66,10 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)

     def __enter__(self):
+        checkpoint = self.config.checkpoint_config.checkpoint
         self.group = ModelParallelProcessGroup(
-            self.args.model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.args),
+            checkpoint.model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
         return self

View file

@@ -2,7 +2,6 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

 import collections
-from enum import Enum, unique
 from typing import Optional, Type

 try:
@@ -11,20 +10,12 @@ try:
     print("Using efficient FP8 operators in FBGEMM.")
 except (ImportError, ModuleNotFoundError):
     print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
+    raise

 import torch
 from torch import nn, Tensor


-@unique
-class FfnQuantizeMode(Enum):
-    FP8_ROWWISE = "fp8_rowwise"
-    NONE = "none"
-
-    def __str__(self) -> str:
-        return self.value
-
-
 class Fp8ScaledWeights:
     # TODO: Ugly trick so torch allows us to replace parameters
     # with our custom Fp8Weights instance. Do this properly.
@@ -84,7 +75,6 @@ def ffn_swiglu(
 def quantize_fp8(
     w: Tensor,
     fp8_activation_scale_ub: float,
-    mode: Optional[FfnQuantizeMode] = None,
     output_device: Optional[torch.device] = None,
 ) -> Fp8RowwiseWeights:
     """Quantize [n, k] weight tensor.
@@ -92,22 +82,45 @@ def quantize_fp8(
     Args:
         w (Tensor): [n, k] input high precision tensor to quantize.
         fp8_activation_scale_ub (float): Upper bound for activation max.
-        mode (FfnQuantizeMode): Quantization mode.
     """
     activation_scale_ub = torch.tensor(
         [fp8_activation_scale_ub],
         dtype=torch.float,
         device="cuda",
     )
-    if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE:  # rowwise
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-        del w
-        return Fp8RowwiseWeights(
-            weight=wq,
-            scale=w_scale,
-            shape=wq.shape,
-            activation_scale_ub=activation_scale_ub,
-        )
+    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+    del w
+    return Fp8RowwiseWeights(
+        weight=wq,
+        scale=w_scale,
+        shape=wq.shape,
+        activation_scale_ub=activation_scale_ub,
+    )
+
+
+@torch.inference_mode()
+def load_fp8(
+    w: Tensor,
+    w_scale: Tensor,
+    fp8_activation_scale_ub: float,
+) -> Fp8RowwiseWeights:
+    """Load FP8 [n, k] weight tensor.
+
+    Args:
+        w (Tensor): [n, k] input FP8.
+        fp8_activation_scale_ub (float): Upper bound for activation max.
+    """
+    activation_scale_ub = torch.tensor(
+        [fp8_activation_scale_ub],
+        dtype=torch.float,
+        device="cuda",
+    )
+    return Fp8RowwiseWeights(
+        weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
+        scale=w_scale.to(device="cuda"),
+        shape=w.shape,
+        activation_scale_ub=activation_scale_ub,
+    )


 def fc_fp8_dynamic(
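These are the two paths the new loader relies on: `quantize_fp8` converts a bf16 weight on the fly, while the new `load_fp8` wraps a weight that was already stored in fp8 together with its per-row scales. A hedged usage sketch (requires CUDA and fbgemm-gpu; the shapes, the import path, and the stand-in tensors are illustrative, and 1200.0 mirrors the loader's default `fp8_activation_scale_ub`):

```python
import torch

from fp8_impls import load_fp8, quantize_fp8  # import path illustrative

w_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# On-the-fly conversion (the bf16 checkpoint-format path).
w_q = quantize_fp8(w_bf16, 1200.0, output_device=torch.device("cuda"))

# Pre-quantized path (the fp8_mixed checkpoint format): weights already fp8, scales loaded separately.
w_fp8 = w_bf16.to(torch.float8_e4m3fn)        # stands in for a tensor read from the checkpoint
scales = torch.rand(4096, dtype=torch.float)  # stands in for one entry of fp8_scales_<rank>.pt
w_loaded = load_fp8(w_fp8, scales, 1200.0)
```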

View file

@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+
+import os
+from typing import Optional
+
+import torch
+
+from models.llama3_1.api.model import Transformer, TransformerBlock
+
+from toolchain.inference.api.config import (
+    CheckpointQuantizationFormat,
+    InlineImplConfig,
+)
+from toolchain.inference.api.datatypes import (
+    QuantizationType,
+)
+
+from termcolor import cprint
+
+
+def is_fbgemm_available() -> bool:
+    try:
+        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+
+        return True
+    except (ImportError, ModuleNotFoundError):
+        return False
+
+
+def convert_to_quantized_model(
+    model: Transformer,
+    config: InlineImplConfig,
+    fp8_activation_scale_ub: Optional[float] = 1200.0,
+) -> Transformer:
+    if config.quantization.type == QuantizationType.bf16.value:
+        return model
+    elif config.quantization.type != QuantizationType.fp8.value:
+        raise ValueError("Only FP8 quantization is supported")
+
+    from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
+
+    checkpoint = config.checkpoint_config.checkpoint
+    # Move weights to GPU with quantization
+    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
+        cprint("Loading fp8 scales...", "yellow")
+        fp8_scales_path = os.path.join(
+            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
+        )
+        assert os.path.isfile(
+            fp8_scales_path
+        ), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+        fp8_scales = torch.load(fp8_scales_path, weights_only=True)
+
+        for block in model.layers:
+            if isinstance(block, TransformerBlock):
+                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
+                    continue
+
+                block.feed_forward.w1.weight = load_fp8(
+                    block.feed_forward.w1.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+                block.feed_forward.w3.weight = load_fp8(
+                    block.feed_forward.w3.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+                block.feed_forward.w2.weight = load_fp8(
+                    block.feed_forward.w2.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+    else:
+        cprint("Quantizing fp8 weights from bf16...", "yellow")
+        for block in model.layers:
+            if isinstance(block, TransformerBlock):
+                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
+                    continue
+                block.feed_forward.w1.weight = quantize_fp8(
+                    block.feed_forward.w1.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+                block.feed_forward.w3.weight = quantize_fp8(
+                    block.feed_forward.w3.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+                block.feed_forward.w2.weight = quantize_fp8(
+                    block.feed_forward.w2.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+
+    for _, parameter in model.named_parameters():
+        if not isinstance(parameter, Fp8ScaledWeights):
+            parameter.data = parameter.to(device="cuda")
+    return model
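For the `fp8_mixed` path, the loader expects a per-rank `fp8_scales_<rank>.pt` file alongside the checkpoint, keyed by layer id, FFN projection, and rank, with the first and last layers skipped. A hedged sketch of producing such a file; the layer count and per-projection output dimensions are made up for illustration and would depend on the model and its model-parallel sharding.

```python
import torch

rank = 0
n_layers = 32                                  # illustrative
dims = {"w1": 14336, "w3": 14336, "w2": 4096}  # illustrative output dims per FFN projection

fp8_scales = {
    f"{layer_id}_feed_forward.{proj}_{rank}": torch.ones(dim, dtype=torch.float)  # one scale per output row
    for layer_id in range(1, n_layers - 1)  # first and last layers are skipped by the loader
    for proj, dim in dims.items()
}
torch.save(fp8_scales, f"fp8_scales_{rank}.pt")
```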

View file

@@ -18,6 +18,12 @@ from fp8.fp8_impls import ffn_swiglu
 from torch import nn


+@dataclass
+class QuantizationArgs:
+    fp8_rowwise: bool = False
+    convert_from_bf16: bool = False
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096
@@ -31,6 +37,8 @@ class ModelArgs:
     rope_theta: float = 500000
     use_scaled_rope: bool = False

+    quantization: Optional[QuantizationArgs] = None
+
     max_batch_size: int = 32
     max_seq_len: int = 2048

View file

@@ -5,7 +5,7 @@ import unittest

 import torch

-from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
+from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
 from hypothesis import given, settings, strategies as st
 from torch import Tensor
@@ -33,70 +33,42 @@ class FP8Tests(unittest.TestCase):
         UB: float,
     ) -> None:
         x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        w13 = (
-            torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-        )
+        w1 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        )
+        w3 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        )
         w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1

-        x_q = quantize_fp8(x, UB)
-        w13_q = quantize_fp8(w13, UB)
-        w2_q = quantize_fp8(w2, UB)
+        x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)

-        def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
+        def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
             (B, T, D) = x.shape
-            (HD_L_2, D_) = w13.shape
+            (HD_L, D_) = w1.shape
             assert D_ == D

-            HD_L = HD_L_2 // 2
-            y = x.view(B * T, D) @ w13.T
-            x1 = y[:, :HD_L]
-            x2 = y[:, HD_L:]
+            x1 = x.view(B * T, D) @ w1.T
+            x2 = x.view(B * T, D) @ w3.T

             z = torch.nn.functional.silu(x1) * x2
             return (z @ w2.T).view(B, T, D).to(torch.bfloat16)

-        v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
+        v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)

         # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        w13 = w13_q.weight.bfloat16() * w13_q.scale
-        w2 = w2_q.weight.bfloat16() * w2_q.scale
+        x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
+        w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
+        w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
+        w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)

-        v_ref = ref_ffn(x, w13, w2)
+        v_ref = ref_ffn(x, w1, w3, w2)
         torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)

-    @settings(deadline=None)
-    @given(
-        B_T=st.sampled_from([2048, 4096]),
-        D=st.sampled_from([128, 256]),
-        HD_L=st.sampled_from([256, 512]),
-        UB=st.sampled_from([1000, 10000]),
-    )
-    def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
-        B_T = 4096
-        D = 256
-        HD_L = 512
-        UB = float(UB)
-
-        x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-
-        x_q = quantize_fp8(x, UB)
-        wqkv_q = quantize_fp8(wqkv, UB)
-
-        num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
-
-        y = attn_linear(x, wqkv_q)
-        y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
-
-        # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
-
-        y_ref = (x @ wqkv.T).to(torch.bfloat16)
-
-        torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
-        torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)


 if __name__ == "__main__":
     unittest.main()