make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config
- also kill GeneratorArgs
Ashwin Bharambe 2024-07-20 21:10:17 -07:00
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions
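
For context, the "quantization-related args" presumably hang off the inference config roughly as in the sketch below. The class and field names (QuantizationConfig, fp8_activation_scale_ub) and the 1200.0 default are illustrative assumptions, not the exact schema from this commit.

# Hypothetical sketch of quantization args on the inference config;
# names and defaults are assumptions, not this commit's exact schema.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class QuantizationType(Enum):
    fp8 = "fp8"


@dataclass
class QuantizationConfig:
    type: QuantizationType = QuantizationType.fp8
    # Upper bound on the activation scale, passed through to
    # quantize_fp8/load_fp8 below; the default is an assumption.
    fp8_activation_scale_ub: float = 1200.0


@dataclass
class InferenceConfig:
    checkpoint_dir: str
    quantization: Optional[QuantizationConfig] = None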


@@ -2,7 +2,6 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import collections
-from enum import Enum, unique
 from typing import Optional, Type
 
 try:
@@ -11,20 +10,12 @@ try:
     print("Using efficient FP8 operators in FBGEMM.")
 except (ImportError, ModuleNotFoundError):
     print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
+    raise
 
 import torch
 from torch import nn, Tensor
 
 
-@unique
-class FfnQuantizeMode(Enum):
-    FP8_ROWWISE = "fp8_rowwise"
-    NONE = "none"
-
-    def __str__(self) -> str:
-        return self.value
-
-
 class Fp8ScaledWeights:
     # TODO: Ugly trick so torch allows us to replace parameters
     # with our custom Fp8Weights instance. Do this properly.
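
The "ugly trick" flagged in that TODO is presumably the usual masquerading pattern: the wrapper advertises itself as nn.Parameter so torch's isinstance checks accept it when it replaces a module's parameter. A minimal sketch of that pattern (not the file's exact code):

# Sketch of the masquerading trick: report __class__ as nn.Parameter so
# isinstance(x, nn.Parameter) passes without actually subclassing it.
from typing import Type

from torch import nn


class MasqueradingWeights:
    @property
    def __class__(self) -> Type[nn.Parameter]:
        # isinstance() consults __class__ when the real type check fails,
        # so torch's parameter checks accept this wrapper.
        return nn.Parameter

    @property
    def grad_fn(self) -> None:
        # Pretend to be a leaf tensor with no autograd history.
        return None


print(isinstance(MasqueradingWeights(), nn.Parameter))  # True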
@@ -84,7 +75,6 @@ def ffn_swiglu(
 def quantize_fp8(
     w: Tensor,
     fp8_activation_scale_ub: float,
-    mode: Optional[FfnQuantizeMode] = None,
     output_device: Optional[torch.device] = None,
 ) -> Fp8RowwiseWeights:
     """Quantize [n, k] weight tensor.
@@ -92,22 +82,45 @@ def quantize_fp8(
     Args:
         w (Tensor): [n, k] input high precision tensor to quantize.
         fp8_activation_scale_ub (float): Upper bound for activation max.
-        mode (FfnQuantizeMode): Quantization mode.
     """
     activation_scale_ub = torch.tensor(
         [fp8_activation_scale_ub],
         dtype=torch.float,
         device="cuda",
     )
-    if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE:  # rowwise
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-        del w
-        return Fp8RowwiseWeights(
-            weight=wq,
-            scale=w_scale,
-            shape=wq.shape,
-            activation_scale_ub=activation_scale_ub,
-        )
+    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+    del w
+    return Fp8RowwiseWeights(
+        weight=wq,
+        scale=w_scale,
+        shape=wq.shape,
+        activation_scale_ub=activation_scale_ub,
+    )
+
+
+@torch.inference_mode()
+def load_fp8(
+    w: Tensor,
+    w_scale: Tensor,
+    fp8_activation_scale_ub: float,
+) -> Fp8RowwiseWeights:
+    """Load FP8 [n, k] weight tensor.
+
+    Args:
+        w (Tensor): [n, k] input FP8.
+        fp8_activation_scale_ub (float): Upper bound for activation max.
+    """
+    activation_scale_ub = torch.tensor(
+        [fp8_activation_scale_ub],
+        dtype=torch.float,
+        device="cuda",
+    )
+    return Fp8RowwiseWeights(
+        weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
+        scale=w_scale.to(device="cuda"),
+        shape=w.shape,
+        activation_scale_ub=activation_scale_ub,
+    )
 
 
 def fc_fp8_dynamic(
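
Taken together, the server now has two ways to obtain FP8 FFN weights: quantize high-precision weights at startup, or restore an already-quantized checkpoint via the new load_fp8, which is what the commit title refers to. A minimal usage sketch; the shapes, the 1200.0 scale upper bound, the checkpoint filename, and the state-dict key names are illustrative assumptions, and a CUDA device with FBGEMM's FP8 operators is required:

import torch

# Path 1: row-wise quantize a bf16 weight at server startup.
w_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
fp8_weight = quantize_fp8(w_bf16, fp8_activation_scale_ub=1200.0)

# Path 2: restore pre-quantized weights and their row-wise scales
# directly from a checkpoint, skipping quantization entirely.
# Filename and key names below are hypothetical.
sd = torch.load("consolidated.00.pth", map_location="cpu")
fp8_weight = load_fp8(
    w=sd["layers.0.feed_forward.w1.weight"],
    w_scale=sd["layers.0.feed_forward.w1.weight_scale"],
    fp8_activation_scale_ub=1200.0,
)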