Initial commit

Ashwin Bharambe 2024-06-25 15:47:57 -07:00 committed by Ashwin Bharambe
commit 5d5acc8ed5
81 changed files with 4458 additions and 0 deletions

View file

@@ -0,0 +1,184 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
from typing import Optional, Type

try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

    print("Using efficient FP8 operators in FBGEMM.")
except ImportError:
    print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
    raise

import torch
from torch import nn, Tensor


class Fp8ScaledWeights:
    # TODO: Ugly trick so torch allows us to replace parameters
    # with our custom Fp8Weights instance. Do this properly.
    @property
    def __class__(self) -> Type[nn.parameter.Parameter]:
        return nn.Parameter

    @property
    def grad_fn(self) -> None:
        return None


# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
    Fp8ScaledWeights,
    collections.namedtuple(
        "Fp8RowwiseWeights",
        ["weight", "scale", "shape", "activation_scale_ub"],
    ),
):
    pass


def ffn_swiglu(
    x: Tensor,
    w1: Fp8RowwiseWeights,
    w3: Fp8RowwiseWeights,
    w2: Fp8RowwiseWeights,
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    if (
        isinstance(w1, Fp8ScaledWeights)
        and isinstance(w3, Fp8ScaledWeights)
        and isinstance(w2, Fp8ScaledWeights)
    ):
        return ffn_swiglu_fp8_dynamic(
            x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
        )

    (B, T, D) = x.shape  # noqa: N806
    (HD_L, D_) = w1.shape  # noqa: N806
    assert D_ == D

    assert isinstance(w1, Tensor)
    assert isinstance(w3, Tensor)
    x1 = x.view(B * T, D) @ w1.T
    x2 = x.view(B * T, D) @ w3.T
    z = torch.nn.functional.silu(x1) * x2
    del x1, x2

    assert isinstance(w2, Tensor)
    return (z @ w2.T).view(B, T, D)


@torch.inference_mode()
def quantize_fp8(
    w: Tensor,
    fp8_activation_scale_ub: float,
    output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
    """Quantize [n, k] weight tensor.

    Args:
        w (Tensor): [n, k] input high precision tensor to quantize.
        fp8_activation_scale_ub (float): Upper bound for activation max.
    """
    activation_scale_ub = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device="cuda",
    )
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
    del w
    return Fp8RowwiseWeights(
        weight=wq,
        scale=w_scale,
        shape=wq.shape,
        activation_scale_ub=activation_scale_ub,
    )


@torch.inference_mode()
def load_fp8(
    w: Tensor,
    w_scale: Tensor,
    fp8_activation_scale_ub: float,
) -> Fp8RowwiseWeights:
    """Load FP8 [n, k] weight tensor.

    Args:
        w (Tensor): [n, k] input FP8.
        fp8_activation_scale_ub (float): Upper bound for activation max.
    """
    activation_scale_ub = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device="cuda",
    )
    return Fp8RowwiseWeights(
        weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
        scale=w_scale.to(device="cuda"),
        shape=w.shape,
        activation_scale_ub=activation_scale_ub,
    )


def fc_fp8_dynamic(
    x: Tensor,
    w: Fp8RowwiseWeights,
    activation_scale_ub: Optional[Tensor] = None,
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """
    Single w8a8 fc layer with dynamic row-wise scaling.
    """
    if isinstance(w, Fp8RowwiseWeights):
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
            x, num_tokens, activation_scale_ub
        )
        y = torch.ops.fbgemm.f8f8bf16_rowwise(
            xq, w.weight, x_scale, w.scale, use_fast_accum=True
        )
        del xq
    return y


def ffn_swiglu_fp8_dynamic(
    x: Tensor,
    w1: Fp8RowwiseWeights,
    w3: Fp8RowwiseWeights,
    w2: Fp8RowwiseWeights,
    activation_scale_ub: Optional[Tensor] = None,
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    (B, T, D) = x.shape  # noqa: N806
    HD_L = w1.shape[0]  # noqa: N806
    assert HD_L == w3.shape[0]
    x1 = fc_fp8_dynamic(
        x.view(B * T, D),
        w1,
        activation_scale_ub,
        num_tokens,
        is_memory_bounded,
    )
    x2 = fc_fp8_dynamic(
        x.view(B * T, D),
        w3,
        activation_scale_ub,
        num_tokens,
        is_memory_bounded,
    )
    z = torch.nn.functional.silu(x1) * x2
    del x1, x2

    z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)

    return z_.view(B, T, D)
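
For orientation, here is a minimal sketch of how the helpers above compose (not part of the commit; the flat import path, tensor shapes, and the 1200.0 scale upper bound are illustrative assumptions, and it needs a CUDA device with the FBGEMM FP8 kernels installed):

import torch

from fp8_impls import ffn_swiglu, quantize_fp8  # assumed import path for this module

# Illustrative dims only: model dim, FFN hidden dim, batch, tokens.
D, HD_L, B, T = 4096, 14336, 1, 128
x = torch.randn(B, T, D, dtype=torch.bfloat16, device="cuda")

# Row-wise FP8 quantization of bf16 FFN weights, with an activation-scale upper bound.
w1 = quantize_fp8(torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda"), 1200.0)
w3 = quantize_fp8(torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda"), 1200.0)
w2 = quantize_fp8(torch.randn(D, HD_L, dtype=torch.bfloat16, device="cuda"), 1200.0)

# All three weights are Fp8ScaledWeights, so ffn_swiglu dispatches to
# ffn_swiglu_fp8_dynamic and returns a [B, T, D] bf16 tensor.
y = ffn_swiglu(x, w1, w3, w2)
print(y.shape)  # torch.Size([1, 128, 4096])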

View file

@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from typing import Optional

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api.config import (
    CheckpointQuantizationFormat,
    InlineImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint
from torch import Tensor


def is_fbgemm_available() -> bool:
    try:
        import fbgemm_gpu.experimental.gen_ai  # noqa: F401

        return True
    except ImportError:
        return False


def swiglu_wrapper(
    self,
    x: Tensor,
):
    from .fp8_impls import ffn_swiglu

    out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
    return reduce_from_model_parallel_region(out)


def convert_to_quantized_model(
    model: Transformer,
    config: InlineImplConfig,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
    if config.quantization.type == QuantizationType.bf16.value:
        return model
    elif config.quantization.type != QuantizationType.fp8.value:
        raise ValueError("Only FP8 quantization is supported")

    from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8

    checkpoint = config.checkpoint_config.checkpoint

    # Move weights to GPU with quantization
    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
        cprint("Loading fp8 scales...", "yellow")
        fp8_scales_path = os.path.join(
            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
        )
        assert os.path.isfile(
            fp8_scales_path
        ), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
        fp8_scales = torch.load(fp8_scales_path, weights_only=True)

        for block in model.layers:
            if isinstance(block, TransformerBlock):
                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                    continue

                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
                for key in ("w1", "w3", "w2"):
                    param = getattr(block.feed_forward, key)
                    param.weight = load_fp8(
                        param.weight,
                        fp8_scales[
                            f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
                        ],
                        fp8_activation_scale_ub,
                    )
    else:
        cprint("Quantizing fp8 weights from bf16...", "yellow")
        for block in model.layers:
            if isinstance(block, TransformerBlock):
                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                    continue
                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
                for key in ("w1", "w3", "w2"):
                    param = getattr(block.feed_forward, key)
                    param.weight = quantize_fp8(
                        param.weight,
                        fp8_activation_scale_ub,
                        output_device=torch.device("cuda"),
                    )

    for _, parameter in model.named_parameters():
        if not isinstance(parameter, Fp8ScaledWeights):
            parameter.data = parameter.to(device="cuda")
    return model
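
A note on the swiglu_wrapper.__get__(block.feed_forward) pattern used above: function.__get__(obj) produces a method bound to that instance, so the FeedForward module keeps its parameters while its forward() is routed through the FP8 kernels without subclassing. A toy sketch of the same trick (not part of the commit; the module and function names here are made up):

import torch
from torch import nn


def patched_forward(self, x):
    # `self` is the module instance the plain function was bound to via __get__.
    print(f"patched forward on {type(self).__name__}")
    return self.linear(x)


m = nn.Module()
m.linear = nn.Linear(4, 4)
m.forward = patched_forward.__get__(m)  # same idea as swiglu_wrapper.__get__(block.feed_forward)
print(m(torch.randn(2, 4)).shape)  # torch.Size([2, 4])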

View file

@@ -0,0 +1,30 @@
#!/bin/bash
if [[ $# -ne 1 ]]; then
  echo "Error: Please provide the name of CONDA environment you wish to create"
  exit 1
fi

ENV_NAME=$1

set -eu
eval "$(conda shell.bash hook)"

echo "Will build env (or overwrite) named '$ENV_NAME'"

set -x

run_build() {
  # Set up the conda environment
  yes | conda remove --name $ENV_NAME --all
  yes | conda create -n $ENV_NAME python=3.10

  conda activate $ENV_NAME

  # PT nightly
  pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121

  # install dependencies for `llama-agentic-system`
  pip install -r fp8_requirements.txt
}

run_build

View file

@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Optional

import fire
import torch
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import Tokenizer
from torch.nn.parameter import Parameter


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    quantized_ckpt_dir: str,
    max_seq_len: Optional[int] = 512,
    max_batch_size: Optional[int] = 4,
    model_parallel_size: Optional[int] = None,
    ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
    seed: int = 1,
):
    """ """
    if not os.path.exists(quantized_ckpt_dir):
        os.makedirs(quantized_ckpt_dir)
        shutil.copy(
            os.path.join(ckpt_dir, "params.json"),
            os.path.join(quantized_ckpt_dir, "params.json"),
        )
        shutil.copy(
            os.path.join(ckpt_dir, "tokenizer.model"),
            os.path.join(quantized_ckpt_dir, "tokenizer.model"),
        )

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")
    if not model_parallel_is_initialized():
        if model_parallel_size is None:
            model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
        initialize_model_parallel(model_parallel_size)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(seed)

    if local_rank > 0:
        sys.stdout = open(os.devnull, "w")

    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
    assert model_parallel_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
    ckpt_path = checkpoints[get_model_parallel_rank()]
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        **params,
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    assert (
        model_args.vocab_size == tokenizer.n_words
    ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

    # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
    torch.set_default_tensor_type(torch.BFloat16Tensor)
    model = Transformer(model_args)
    model.load_state_dict(checkpoint, strict=False)

    if torch.cuda.is_bf16_supported():
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
    else:
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    print(ckpt_path)
    assert (
        quantized_ckpt_dir is not None
    ), "Quantized checkpoint directory should not be None"
    fp8_scales = {}
    for block in model.layers:
        if isinstance(block, TransformerBlock):
            if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                continue

            fp8_weight = quantize_fp8(
                block.feed_forward.w1.weight,
                fp8_activation_scale_ub,
                ffn_quantize_mode,
                output_device=torch.device("cpu"),
            )
            with torch.inference_mode():
                block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
            fp8_scales[
                f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
            ] = fp8_weight.scale

            fp8_weight = quantize_fp8(
                block.feed_forward.w3.weight,
                fp8_activation_scale_ub,
                ffn_quantize_mode,
                output_device=torch.device("cpu"),
            )
            with torch.inference_mode():
                block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
            fp8_scales[
                f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
            ] = fp8_weight.scale

            fp8_weight = quantize_fp8(
                block.feed_forward.w2.weight,
                fp8_activation_scale_ub,
                ffn_quantize_mode,
                output_device=torch.device("cpu"),
            )
            with torch.inference_mode():
                block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
            fp8_scales[
                f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
            ] = fp8_weight.scale

    fp8_scales_path = os.path.join(
        quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
    )
    torch.save(fp8_scales, fp8_scales_path)

    ckpt_path = os.path.join(
        quantized_ckpt_dir,
        "consolidated.{:02d}.pth".format(get_model_parallel_rank()),
    )
    torch.save(model.state_dict(), ckpt_path)


if __name__ == "__main__":
    fire.Fire(main)
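
For reference, a hedged sketch of inspecting what the script writes out (the directory path and rank are placeholders; the key format follows the f"{block.layer_id}_feed_forward.{w}_{rank}" strings above):

import torch

rank = 0
quantized_ckpt_dir = "/path/to/quantized_ckpt"  # placeholder

# Per-rank FP8 scales, saved as fp8_scales_{rank}.pt and keyed by layer / weight / rank.
scales = torch.load(f"{quantized_ckpt_dir}/fp8_scales_{rank}.pt", weights_only=True)
print(len(scales), "scale tensors")
print(sorted(scales)[:3])  # e.g. keys like '2_feed_forward.w1_0'

# The quantized consolidated checkpoint itself, one shard per model-parallel rank.
state = torch.load(
    f"{quantized_ckpt_dir}/consolidated.{rank:02d}.pth",
    map_location="cpu",
    weights_only=True,
)
print(next(iter(state)))  # first parameter name in the shard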

View file

@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
set -x
cd $(git rev-parse --show-toplevel)

MASTER_HOST=$1
RUN_ID=$2
CKPT_DIR=$3
QUANT_CKPT_DIR=$4
TOKENIZER_PATH=$5
NNODES=$6
NPROC=$7

echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR

NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
  torchrun \
    --nnodes=$NNODES --nproc_per_node=$NPROC \
    --rdzv_id=$RUN_ID \
    --rdzv_conf='timeout=120' \
    --rdzv_backend=c10d \
    --rdzv_endpoint="${MASTER_HOST}:29502" \
    quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR

View file

@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import unittest

import torch
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
from hypothesis import given, settings, strategies as st
from torch import Tensor


@unittest.skipIf(
    not torch.cuda.is_available()
    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
    "Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):
    @settings(deadline=None)
    @given(
        D=st.sampled_from([4096, 8192]),
        HD_L=st.sampled_from([1280, 2560]),
        B=st.sampled_from([1, 2]),
        T=st.sampled_from([2048, 4096]),
        UB=st.sampled_from([1000, 10000]),
    )
    def test_fp8_ffn(
        self,
        D: int,  # noqa
        HD_L: int,
        B: int,
        T: int,
        UB: float,
    ) -> None:
        x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
        w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1

        x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
        w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
        w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
        w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)

        def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
            (B, T, D) = x.shape  # noqa: N806
            (HD_L, D_) = w1.shape  # noqa: N806
            assert D_ == D

            x1 = x.view(B * T, D) @ w1.T
            x2 = x.view(B * T, D) @ w3.T
            z = torch.nn.functional.silu(x1) * x2
            return (z @ w2.T).view(B, T, D).to(torch.bfloat16)

        v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)

        # Fake quant
        x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
        w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
        w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
        w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)

        v_ref = ref_ffn(x, w1, w3, w2)
        torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)


if __name__ == "__main__":
    unittest.main()
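
The test compares the FP8 FFN against a bf16 reference built from fake-quantized inputs (weight.bfloat16() * scale.unsqueeze(-1)). For intuition, here is a minimal sketch of row-wise FP8 fake quantization in plain torch — an illustrative approximation, not the exact rounding that FBGEMM's quantize_fp8_per_row performs:

import torch

w = torch.randn(4, 8, dtype=torch.bfloat16) * 0.01

# float8_e4m3fn has a max representable magnitude of 448, so scale each row into that range.
row_max = w.abs().amax(dim=1).to(torch.float)
scale = torch.clamp(row_max, min=1e-12) / 448.0
wq = (w.float() / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)

# Dequantize the same way the test does: cast to bf16, multiply by the per-row scale.
w_fake = wq.to(torch.bfloat16) * scale.unsqueeze(-1).to(torch.bfloat16)
print((w_fake - w).abs().max())  # small row-wise quantization error

The suite itself runs with python test_fp8.py on a Hopper-class GPU; hypothesis drives the shape sampling, and the skipIf guard skips it elsewhere.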