mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
make inference server load checkpoints for fp8 inference
- introduce quantization related args for inference config - also kill GeneratorArgs
This commit is contained in:
parent
7d2c0b14b8
commit
ad62e2e1f3
10 changed files with 249 additions and 155 deletions
|
@ -2,7 +2,6 @@
|
|||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import collections
|
||||
from enum import Enum, unique
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
|
@ -11,20 +10,12 @@ try:
|
|||
print("Using efficient FP8 operators in FBGEMM.")
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
@unique
|
||||
class FfnQuantizeMode(Enum):
|
||||
FP8_ROWWISE = "fp8_rowwise"
|
||||
NONE = "none"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
|
@ -84,7 +75,6 @@ def ffn_swiglu(
|
|||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
mode: Optional[FfnQuantizeMode] = None,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
@ -92,22 +82,45 @@ def quantize_fp8(
|
|||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
mode (FfnQuantizeMode): Quantization mode.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE: # rowwise
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
|
||||
scale=w_scale.to(device="cuda"),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
def fc_fp8_dynamic(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue