forked from phoenix-oss/llama-stack-mirror
impls
-> inline
, adapters
-> remote
(#381)
This commit is contained in:
parent
b10e9f46bb
commit
994732e2e0
169 changed files with 106 additions and 105 deletions
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,184 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import collections
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
print("Using efficient FP8 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Fp8RowwiseWeights(
|
||||
Fp8ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Fp8RowwiseWeights",
|
||||
["weight", "scale", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if (
|
||||
isinstance(w1, Fp8ScaledWeights)
|
||||
and isinstance(w3, Fp8ScaledWeights)
|
||||
and isinstance(w2, Fp8ScaledWeights)
|
||||
):
|
||||
return ffn_swiglu_fp8_dynamic(
|
||||
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
|
||||
)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
assert isinstance(w3, Tensor)
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
assert isinstance(w2, Tensor)
|
||||
return (z @ w2.T).view(B, T, D)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
|
||||
scale=w_scale.to(device="cuda"),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
def fc_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Single w8a8 fc layer with dynamic row-wise scaling.
|
||||
"""
|
||||
if isinstance(w, Fp8RowwiseWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
x, num_tokens, activation_scale_ub
|
||||
)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
xq, w.weight, x_scale, w.scale, use_fast_accum=True
|
||||
)
|
||||
del xq
|
||||
return y
|
||||
|
||||
|
||||
def ffn_swiglu_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w1,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
x2 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w3,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
|
||||
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
return z_.view(B, T, D)
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
|
||||
from hypothesis import given, settings, strategies as st
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||
"Skip when H100 is not available",
|
||||
)
|
||||
class FP8Tests(unittest.TestCase):
|
||||
@settings(deadline=None)
|
||||
@given(
|
||||
D=st.sampled_from([4096, 8192]),
|
||||
HD_L=st.sampled_from([1280, 2560]),
|
||||
B=st.sampled_from([1, 2]),
|
||||
T=st.sampled_from([2048, 4096]),
|
||||
UB=st.sampled_from([1000, 10000]),
|
||||
)
|
||||
def test_fp8_ffn(
|
||||
self,
|
||||
D: int, # noqa
|
||||
HD_L: int,
|
||||
B: int,
|
||||
T: int,
|
||||
UB: float,
|
||||
) -> None:
|
||||
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
|
||||
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
|
||||
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
|
||||
|
||||
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
|
||||
|
||||
# Fake quant
|
||||
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
|
||||
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
|
||||
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
|
||||
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
|
||||
|
||||
v_ref = ref_ffn(x, w1, w3, w2)
|
||||
|
||||
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Hadamard transform.
|
||||
|
||||
This function performs the Hadamard transform on the input tensor 'x'.
|
||||
The Hadamard transform is a linear transformation that multiplies the input
|
||||
tensor by the Hadamard matrix of dimension n x n, where n is the size of
|
||||
the last dimension of the input tensor.
|
||||
"""
|
||||
*_, n = x.shape
|
||||
m = int(math.log2(n))
|
||||
assert n == 1 << m, "n must be a power of 2"
|
||||
x = x[..., None]
|
||||
inv_sqrt2 = 0.5**0.5
|
||||
for _ in range(m):
|
||||
top = x[..., ::2, :] + x[..., 1::2, :]
|
||||
bot = x[..., ::2, :] - x[..., 1::2, :]
|
||||
x = torch.cat((top, bot), dim=-1)
|
||||
x *= inv_sqrt2
|
||||
res = x.squeeze(-2)
|
||||
return res
|
||||
|
||||
|
||||
class HadamardModule(torch.nn.Module):
|
||||
"""A module that applies the Hadamard transform to the input tensor.
|
||||
|
||||
Args:
|
||||
group_size: The size of the groups that the input tensor will be divided into
|
||||
before applying the Hadamard transform.
|
||||
"""
|
||||
|
||||
def __init__(self, group_size: int) -> None:
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
reshape_back = False
|
||||
orig_shape = x.shape
|
||||
if self.group_size != x.shape[-1]:
|
||||
reshape_back = True
|
||||
x = x.reshape(-1, x.shape[-1] // self.group_size, self.group_size)
|
||||
x = hadamard_transform(x)
|
||||
if reshape_back:
|
||||
x = x.reshape(orig_shape)
|
||||
return x
|
||||
|
||||
|
||||
def add_hadamard_transform_for_spinquant(
|
||||
model: torch.nn.Module, prefix: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
|
||||
This function recursively traverses the model's children and looks for layers that match the pattern
|
||||
"layers.<digit>.feed_forward.w2", where <digit> is one or more digits. When such a layer is found,
|
||||
it is replaced with a new sequential module that consists of a HadamardModule followed by the original
|
||||
layer. The HadamardModule applies the Hadamard transform to the input tensor.
|
||||
|
||||
See `SpinQuant <https://arxiv.org/abs/2405.16406>_` paper for more details.
|
||||
|
||||
Args:
|
||||
model: An instance of 'torch.nn.Module' (e.g., Transformer model).
|
||||
prefix: A string prefix to add to the full name of each child module.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
pattern_last_linear_ffn = r"layers.\d+.feed_forward.w2"
|
||||
for module_name, module in model.named_children():
|
||||
child_full_name = prefix + "." + module_name
|
||||
if re.search(pattern_last_linear_ffn, child_full_name):
|
||||
new_module = nn.Sequential(
|
||||
HadamardModule(group_size=module.in_features), module
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, new_module)
|
||||
else:
|
||||
add_hadamard_transform_for_spinquant(
|
||||
module, (prefix + "." if prefix else prefix) + module_name
|
||||
)
|
|
@ -0,0 +1,339 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
from torch import nn, Tensor
|
||||
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from .fp8_impls import ffn_swiglu
|
||||
|
||||
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
||||
def convert_to_fp8_quantized_model(
|
||||
model: Transformer,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
checkpoint_dir: str,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
) -> Transformer:
|
||||
if config.quantization.type == QuantizationType.bf16.value:
|
||||
return model
|
||||
|
||||
elif config.quantization.type != QuantizationType.fp8.value:
|
||||
raise ValueError("Only FP8 quantization is supported")
|
||||
|
||||
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
|
||||
|
||||
llama_model = resolve_model(config.model)
|
||||
assert llama_model is not None, f"Model {config.model} not found"
|
||||
|
||||
# Move weights to GPU with quantization
|
||||
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
cprint("Loading fp8 scales...", "yellow")
|
||||
fp8_scales_path = os.path.join(
|
||||
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
assert os.path.isfile(
|
||||
fp8_scales_path
|
||||
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = load_fp8(
|
||||
param.weight,
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
|
||||
],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
cprint("Quantizing fp8 weights from bf16...", "yellow")
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = quantize_fp8(
|
||||
param.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
return model
|
||||
|
||||
|
||||
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
||||
"""
|
||||
Int8DynActInt4WeightLinear with LoRA adaptor.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
device: Device to use.
|
||||
group_size: Group size for quantization.
|
||||
precision: Precision of quantization.
|
||||
scales_precision: Precision of scales.
|
||||
lora_rank: Rank of LoRA adaptor.
|
||||
lora_scale: Scale of LoRA adaptor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias=False,
|
||||
device=None,
|
||||
# quantization parameters
|
||||
group_size: int = 256,
|
||||
precision: torch.dtype = torch.float32,
|
||||
scales_precision: torch.dtype = torch.float32,
|
||||
# LoRA parameters
|
||||
lora_rank: Optional[int] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
groupsize=group_size,
|
||||
precision=precision,
|
||||
scales_precision=scales_precision,
|
||||
)
|
||||
if lora_rank is not None:
|
||||
assert lora_scale is not None, "Please specify lora scale for LoRA."
|
||||
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
|
||||
self.adaptor = nn.Sequential()
|
||||
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
|
||||
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
|
||||
self.lora_scale = lora_scale
|
||||
else:
|
||||
self.adaptor = None
|
||||
self.lora_scale = None
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized weights from the state dict."""
|
||||
if prefix + "zeros" not in state_dict:
|
||||
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
|
||||
assert prefix + "scales" in state_dict
|
||||
state_dict[prefix + "zeros"] = torch.zeros_like(
|
||||
state_dict[prefix + "scales"]
|
||||
)
|
||||
|
||||
def forward(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
module_out = super().forward(input_)
|
||||
if self.adaptor is not None:
|
||||
adaptor_out = self.adaptor(input_) * self.lora_scale
|
||||
return module_out + adaptor_out
|
||||
return module_out
|
||||
|
||||
|
||||
class Int8WeightEmbedding(torch.nn.Embedding):
|
||||
"""An embedding layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
num_embeddings: Number of embeddings.
|
||||
embedding_dim: Embedding dimension.
|
||||
padding_idx: Padding index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int,
|
||||
device=None,
|
||||
) -> None:
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized embedding weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
class Int8WeightLinear(torch.nn.Linear):
|
||||
"""A linear layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features: int, out_features: int, bias: bool = True, device=None
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, bias, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized linear weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
def _prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model: torch.nn.Module,
|
||||
group_size: int,
|
||||
lora_rank: Optional[int],
|
||||
lora_scale: Optional[float],
|
||||
):
|
||||
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
|
||||
|
||||
Note that the weights of embedding and output layers are quantized to int8.
|
||||
"""
|
||||
device = None
|
||||
for module_name, module in model.named_children():
|
||||
if module_name == "output":
|
||||
quantized_module = Int8WeightLinear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif module_name == "tok_embeddings":
|
||||
quantized_module = Int8WeightEmbedding(
|
||||
num_embeddings=module.num_embeddings,
|
||||
embedding_dim=module.embedding_dim,
|
||||
padding_idx=module.padding_idx,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
|
||||
quantized_module = Int8DynActInt4WeightLinearLoRA(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=False,
|
||||
group_size=group_size,
|
||||
lora_rank=lora_rank,
|
||||
lora_scale=lora_scale,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
else:
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
module, group_size, lora_rank, lora_scale
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert_to_int4_quantized_model(
|
||||
model: Transformer,
|
||||
model_args: ModelArgs,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
) -> Transformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
|
||||
if model_args.quantization_args is None:
|
||||
raise ValueError("'quantization_args' cannot be None. Please specify it.")
|
||||
|
||||
quantization_args = model_args.quantization_args
|
||||
|
||||
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
|
||||
raise NotImplementedError(
|
||||
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
|
||||
)
|
||||
|
||||
group_size = model_args.quantization_args.group_size
|
||||
if group_size is None:
|
||||
raise ValueError(
|
||||
"'group_size' cannot be None in 'quantization_args'. Please specify it."
|
||||
)
|
||||
|
||||
if model_args.lora_args is None:
|
||||
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
|
||||
lora_rank = None
|
||||
lora_scale = None
|
||||
else:
|
||||
lora_rank = model_args.lora_args.rank
|
||||
lora_scale = model_args.lora_args.scale
|
||||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model, group_size, lora_rank, lora_scale
|
||||
)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
return model.to(device)
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,36 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
if [[ $# -ne 1 ]]; then
|
||||
echo "Error: Please provide the name of CONDA environment you wish to create"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ENV_NAME=$1
|
||||
|
||||
set -eu
|
||||
eval "$(conda shell.bash hook)"
|
||||
|
||||
echo "Will build env (or overwrite) named '$ENV_NAME'"
|
||||
|
||||
set -x
|
||||
|
||||
run_build() {
|
||||
# Set up the conda environment
|
||||
yes | conda remove --name $ENV_NAME --all
|
||||
yes | conda create -n $ENV_NAME python=3.10
|
||||
conda activate $ENV_NAME
|
||||
|
||||
# PT nightly
|
||||
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
|
||||
|
||||
# install dependencies for `llama-agentic-system`
|
||||
pip install -r fp8_requirements.txt
|
||||
}
|
||||
|
||||
run_build
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
|
||||
|
||||
from llama.model import ModelArgs, Transformer, TransformerBlock
|
||||
from llama.tokenizer import Tokenizer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
def main(
|
||||
ckpt_dir: str,
|
||||
tokenizer_path: str,
|
||||
quantized_ckpt_dir: str,
|
||||
max_seq_len: Optional[int] = 512,
|
||||
max_batch_size: Optional[int] = 4,
|
||||
model_parallel_size: Optional[int] = None,
|
||||
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
seed: int = 1,
|
||||
):
|
||||
""" """
|
||||
if not os.path.exists(quantized_ckpt_dir):
|
||||
os.makedirs(quantized_ckpt_dir)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "params.json"),
|
||||
os.path.join(quantized_ckpt_dir, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "tokenizer.model"),
|
||||
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
|
||||
)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
if not model_parallel_is_initialized():
|
||||
if model_parallel_size is None:
|
||||
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(
|
||||
checkpoints
|
||||
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
print(ckpt_path)
|
||||
assert (
|
||||
quantized_ckpt_dir is not None
|
||||
), "QUantized checkpoint directory should not be None"
|
||||
fp8_scales = {}
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w1.weight,
|
||||
fp8_activation_scale_ub,
|
||||
ffn_quantize_mode,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w3.weight,
|
||||
fp8_activation_scale_ub,
|
||||
ffn_quantize_mode,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w2.weight,
|
||||
fp8_activation_scale_ub,
|
||||
ffn_quantize_mode,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
|
||||
fp8_scales_path = os.path.join(
|
||||
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
torch.save(fp8_scales, fp8_scales_path)
|
||||
|
||||
ckpt_path = os.path.join(
|
||||
quantized_ckpt_dir,
|
||||
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
|
||||
)
|
||||
torch.save(model.state_dict(), ckpt_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
|
@ -0,0 +1,31 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
||||
cd $(git rev-parse --show-toplevel)
|
||||
|
||||
MASTER_HOST=$1
|
||||
RUN_ID=$2
|
||||
CKPT_DIR=$3
|
||||
QUANT_CKPT_DIR=$4
|
||||
TOKENIZER_PATH=$5
|
||||
NNODES=$6
|
||||
NPROC=$7
|
||||
|
||||
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
||||
|
||||
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
|
||||
torchrun \
|
||||
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
||||
--rdzv_id=$RUN_ID \
|
||||
--rdzv_conf='timeout=120' \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint="${MASTER_HOST}:29502" \
|
||||
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR
|
Loading…
Add table
Add a link
Reference in a new issue