# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import collections
from typing import Optional, Type

try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

    print("Using efficient FP8 operators in FBGEMM.")
except ImportError:
    print("No efficient FP8 operators. Please install FBGEMM (see fp8_requirements.txt).")
    raise

import torch
from torch import nn, Tensor


class Fp8ScaledWeights:
    # TODO: Ugly trick so torch allows us to replace parameters
    # with our custom Fp8Weights instance. Do this properly.
    @property
    def __class__(self) -> Type[nn.parameter.Parameter]:
        return nn.Parameter

    @property
    def grad_fn(self) -> None:
        return None


# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
    Fp8ScaledWeights,
    collections.namedtuple(
        "Fp8RowwiseWeights",
        ["weight", "scale", "shape", "activation_scale_ub"],
    ),
):
    pass


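# Illustrative sketch, not part of the original module: the __class__ and
# grad_fn properties on Fp8ScaledWeights exist so that an Fp8RowwiseWeights
# namedtuple can be dropped into a module's parameter slot.
# nn.Module.__setattr__ / register_parameter() check
# isinstance(value, nn.Parameter) and require value.grad_fn to be None; the
# two properties satisfy exactly those checks. The helper name is hypothetical.
def _example_swap_in_fp8_weight(linear: nn.Linear, fp8_w: Fp8RowwiseWeights) -> None:
    # Passes torch's parameter type check even though fp8_w is a namedtuple.
    linear.weight = fp8_w

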
def ffn_swiglu(
    x: Tensor,
    w1: Fp8RowwiseWeights,
    w3: Fp8RowwiseWeights,
    w2: Fp8RowwiseWeights,
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    # Dispatch to the FP8 fast path only when all three projection weights
    # are FP8-scaled; otherwise fall back to plain high-precision matmuls.
    if (
        isinstance(w1, Fp8ScaledWeights)
        and isinstance(w3, Fp8ScaledWeights)
        and isinstance(w2, Fp8ScaledWeights)
    ):
        return ffn_swiglu_fp8_dynamic(
            x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
        )

    (B, T, D) = x.shape  # noqa: N806
    (HD_L, D_) = w1.shape  # noqa: N806
    assert D_ == D

    assert isinstance(w1, Tensor)
    assert isinstance(w3, Tensor)
    x1 = x.view(B * T, D) @ w1.T
    x2 = x.view(B * T, D) @ w3.T
    z = torch.nn.functional.silu(x1) * x2
    del x1, x2
    assert isinstance(w2, Tensor)
    return (z @ w2.T).view(B, T, D)


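# Illustrative sketch, not part of the original module: the high-precision
# fallback of ffn_swiglu computes silu(x @ w1.T) * (x @ w3.T) @ w2.T with
# plain tensors. All shapes below are hypothetical.
def _example_ffn_swiglu_bf16() -> Tensor:
    B, T, D, HD_L = 2, 8, 64, 128
    x = torch.randn(B, T, D, dtype=torch.bfloat16, device="cuda")
    w1 = torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda")
    w3 = torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn(D, HD_L, dtype=torch.bfloat16, device="cuda")
    return ffn_swiglu(x, w1, w3, w2)  # -> [B, T, D]

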
@torch.inference_mode()
def quantize_fp8(
    w: Tensor,
    fp8_activation_scale_ub: float,
    output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
    """Quantize [n, k] weight tensor.

    Args:
        w (Tensor): [n, k] input high precision tensor to quantize.
        fp8_activation_scale_ub (float): Upper bound for activation max.
    """
    # NOTE: output_device is currently unused.
    activation_scale_ub = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device="cuda",
    )
    # Row-wise quantization: one FP8 weight matrix plus one scale per row.
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
    del w
    return Fp8RowwiseWeights(
        weight=wq,
        scale=w_scale,
        shape=wq.shape,
        activation_scale_ub=activation_scale_ub,
    )


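# Illustrative sketch, not part of the original module: quantize a bf16 weight
# into rowwise FP8 form. The upper bound 1200.0 is an arbitrary example value,
# not a recommendation from this file.
def _example_quantize() -> Fp8RowwiseWeights:
    w = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
    return quantize_fp8(w, fp8_activation_scale_ub=1200.0)

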
@torch.inference_mode()
def load_fp8(
    w: Tensor,
    w_scale: Tensor,
    fp8_activation_scale_ub: float,
) -> Fp8RowwiseWeights:
    """Load FP8 [n, k] weight tensor.

    Args:
        w (Tensor): [n, k] input FP8.
        w_scale (Tensor): [n] per-row scales for `w`.
        fp8_activation_scale_ub (float): Upper bound for activation max.
    """
    activation_scale_ub = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device="cuda",
    )
    return Fp8RowwiseWeights(
        weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
        scale=w_scale.to(device="cuda"),
        shape=w.shape,
        activation_scale_ub=activation_scale_ub,
    )


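# Illustrative sketch, not part of the original module: rehydrate a weight that
# was serialized together with its per-row scales (e.g. from a pre-quantized
# checkpoint) instead of re-quantizing at load time. Values are placeholders.
def _example_load() -> Fp8RowwiseWeights:
    w = torch.randn(128, 256)                     # stands in for stored FP8 weight data
    w_scale = torch.ones(128, dtype=torch.float)  # one dequantization scale per row
    return load_fp8(w, w_scale, fp8_activation_scale_ub=1200.0)

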
def fc_fp8_dynamic(
    x: Tensor,
    w: Fp8RowwiseWeights,
    activation_scale_ub: Optional[Tensor] = None,
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """
    Single w8a8 fc layer with dynamic row-wise scaling.
    """
    assert isinstance(w, Fp8RowwiseWeights)
    # Quantize the activations on the fly with per-row scales (capped by
    # activation_scale_ub), then run the FP8 x FP8 -> bf16 rowwise GEMM.
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
        x, num_tokens, activation_scale_ub
    )
    y = torch.ops.fbgemm.f8f8bf16_rowwise(
        xq, w.weight, x_scale, w.scale, use_fast_accum=True
    )
    del xq
    return y


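# Illustrative sketch, not part of the original module: a single dynamic w8a8
# linear layer. Activations are quantized per row at call time and the GEMM
# returns bf16. Shapes are hypothetical.
def _example_fc() -> Tensor:
    w = quantize_fp8(torch.randn(128, 256, dtype=torch.bfloat16, device="cuda"), 1200.0)
    x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda")
    return fc_fp8_dynamic(x, w, w.activation_scale_ub)  # -> [16, 128]

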
def ffn_swiglu_fp8_dynamic(
    x: Tensor,
    w1: Fp8RowwiseWeights,
    w3: Fp8RowwiseWeights,
    w2: Fp8RowwiseWeights,
    activation_scale_ub: Optional[Tensor] = None,
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    (B, T, D) = x.shape  # noqa: N806
    HD_L = w1.shape[0]  # noqa: N806
    assert HD_L == w3.shape[0]
    x1 = fc_fp8_dynamic(
        x.view(B * T, D),
        w1,
        activation_scale_ub,
        num_tokens,
        is_memory_bounded,
    )
    x2 = fc_fp8_dynamic(
        x.view(B * T, D),
        w3,
        activation_scale_ub,
        num_tokens,
        is_memory_bounded,
    )
    z = torch.nn.functional.silu(x1) * x2
    del x1, x2

    z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)

    return z_.view(B, T, D)


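# Illustrative sketch, not part of the original module: an end-to-end FP8
# SwiGLU FFN. Because w1/w3/w2 are Fp8RowwiseWeights, ffn_swiglu dispatches to
# ffn_swiglu_fp8_dynamic automatically. Shapes and the 1200.0 bound are
# hypothetical.
def _example_fp8_ffn() -> Tensor:
    B, T, D, HD_L = 2, 8, 64, 128
    ub = 1200.0
    w1 = quantize_fp8(torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda"), ub)
    w3 = quantize_fp8(torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda"), ub)
    w2 = quantize_fp8(torch.randn(D, HD_L, dtype=torch.bfloat16, device="cuda"), ub)
    x = torch.randn(B, T, D, dtype=torch.bfloat16, device="cuda")
    return ffn_swiglu(x, w1, w3, w2)  # -> [B, T, D]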