llama-stack-mirror/llama_stack/models/llama/quantize_impls.py
Ihar Hrachyshka 9e6561a1ec
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
2025-05-01 14:23:50 -07:00

315 lines
8.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# type: ignore
import collections
import logging
log = logging.getLogger(__name__)
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
except ImportError:
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
raise
import torch
from torch import Tensor, nn
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
Fp8ScaledWeights,
collections.namedtuple(
"Fp8RowwiseWeights",
["weight", "scale", "shape", "activation_scale_ub"],
),
):
pass
class Int4ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Int4Weights instance. Do this properly.
@property
def __class__(self) -> type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Int4Weights(
Int4ScaledWeights,
collections.namedtuple(
"Int4Weights",
["weight", "scale", "zero_point", "shape"],
),
):
pass
def int4_row_quantize(
x: torch.Tensor,
group_size: int = 128,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float)
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
# Recenter output and move to int8.
out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
# Cutlass expects column major layout for scale and zero point,
# so we transpose here and make them contiguous.
scales = scales.view(x.shape[0], -1).t().contiguous()
zeros = zeros.view(x.shape[0], -1).t().contiguous()
return out, scales, zeros
def pack_int4(x: torch.Tensor) -> torch.Tensor:
# Given int8 x, pack adjacent int4 values into a single int8.
low_x = x[:, ::2]
high_x = x[:, 1::2]
# High bits need to left shift, this also masks off extra bits.
high_x = torch.bitwise_left_shift(high_x, 4)
# Low bits need to have sign bits removed.
low_x = torch.bitwise_and(low_x, 0xF)
# Recombine into a single value with bitwise or.
return torch.bitwise_or(low_x, high_x).contiguous()
def bmm_nt(
x: Tensor,
w: Fp8RowwiseWeights | Int4Weights,
num_tokens: Tensor | None = None,
) -> Tensor:
if isinstance(w, Fp8ScaledWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
elif isinstance(w, Int4ScaledWeights):
return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
else:
raise ValueError("Unsupported quantization type")
def ffn_swiglu(
x: Tensor,
w1: Fp8RowwiseWeights | Int4Weights,
w3: Fp8RowwiseWeights | Int4Weights,
w2: Fp8RowwiseWeights | Int4Weights,
num_tokens: Tensor | None = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
):
return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
assert isinstance(w1, Tensor)
assert isinstance(w3, Tensor)
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
del x1, x2
assert isinstance(w2, Tensor)
return (z @ w2.T).view(B, T, D)
@torch.inference_mode()
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
weight=wq,
scale=w_scale,
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def quantize_int4(
w: Tensor,
output_device: torch.device | None = None,
) -> Int4Weights:
"""Quantize [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input high precision tensor to quantize.
"""
if w.ndim >= 3:
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
scale = torch.stack(scale, dim=0)
zero_point = torch.stack(zero_point, dim=0)
else:
wq, scale, zero_point = int4_row_quantize(w)
wq = pack_int4(wq)
del w
return Int4Weights(
weight=wq.to(output_device),
scale=scale.to(output_device),
zero_point=zero_point.to(output_device),
shape=wq.shape,
)
@torch.inference_mode()
def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
Args:
w (Tensor): [n, k] input FP8.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
return Fp8RowwiseWeights(
weight=w.to(torch.float8_e4m3fn).to(device=output_device),
scale=w_scale.to(device=output_device),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_int4(
w: Tensor,
scale: Tensor,
zero_point: Tensor,
output_device: torch.device | None = None,
) -> Int4Weights:
"""Load INT4 [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input INT4.
"""
return Int4Weights(
weight=w.to(torch.int8).to(device=output_device),
scale=scale.to(device=output_device),
zero_point=zero_point.to(device=output_device),
shape=w.shape,
)
def fc_dynamic(
x: Tensor,
w: Fp8RowwiseWeights | Int4Weights,
activation_scale_ub: Tensor | None = None,
num_tokens: Tensor | None = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dyanmic row-wise scaling
"""
if isinstance(w, Int4Weights):
y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
else:
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
del xq
return y
def ffn_swiglu_dynamic(
x: Tensor,
w1: Fp8RowwiseWeights | Int4Weights,
w3: Fp8RowwiseWeights | Int4Weights,
w2: Fp8RowwiseWeights | Int4Weights,
activation_scale_ub: Tensor | None = None,
num_tokens: Tensor | None = None,
is_memory_bounded: bool = False,
) -> Tensor:
assert x.dim() == 3 or x.dim() == 2
if x.dim() == 3:
(B, T, D) = x.shape # noqa: N806
else:
(T, D) = x.shape # noqa: N806
B = 1 # noqa: N806
HD_L = w1.shape[0] # noqa: N806
assert HD_L == w3.shape[0]
x1 = fc_dynamic(
x.view(B * T, D),
w1,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
x2 = fc_dynamic(
x.view(B * T, D),
w3,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
z = torch.nn.functional.silu(x1) * x2
del x1, x2
z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
if x.dim() == 3:
return z_.view(B, T, D)
else:
return z_