mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
315 lines
8.9 KiB
Python
315 lines
8.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
# type: ignore
|
|
import collections
|
|
import logging
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
try:
|
|
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
|
|
|
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
|
|
except ImportError:
|
|
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
|
|
raise
|
|
|
|
import torch
|
|
from torch import Tensor, nn
|
|
|
|
|
|
class Fp8ScaledWeights:
    """Mixin that makes FP8 weight containers report themselves as ``nn.Parameter``.

    TODO: Ugly trick so torch allows us to replace parameters
    with our custom Fp8Weights instance. Do this properly.
    """

    @property
    def __class__(self) -> type[nn.parameter.Parameter]:
        # ``obj.__class__`` answers ``nn.Parameter`` while ``type(obj)`` is
        # still the real container class.
        return nn.Parameter

    @property
    def grad_fn(self) -> None:
        # Always reports that there is no autograd function attached.
        return None
|
|
|
|
|
|
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
|
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
|
class Fp8RowwiseWeights(
    Fp8ScaledWeights,
    collections.namedtuple(
        "Fp8RowwiseWeights",
        ("weight", "scale", "shape", "activation_scale_ub"),
    ),
):
    """Immutable record describing a row-wise FP8 quantized weight.

    Fields, in order: ``weight``, ``scale``, ``shape`` and
    ``activation_scale_ub``. The ``Fp8ScaledWeights`` base supplies the
    ``__class__``/``grad_fn`` overrides.
    """
|
|
|
|
|
|
class Int4ScaledWeights:
    """Mixin that makes INT4 weight containers report themselves as ``nn.Parameter``.

    TODO: Ugly trick so torch allows us to replace parameters
    with our custom Int4Weights instance. Do this properly.
    """

    @property
    def __class__(self) -> type[nn.parameter.Parameter]:
        # ``obj.__class__`` answers ``nn.Parameter`` while ``type(obj)`` is
        # still the real container class.
        return nn.Parameter

    @property
    def grad_fn(self) -> None:
        # Always reports that there is no autograd function attached.
        return None
|
|
|
|
|
|
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
|
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
|
class Int4Weights(
    Int4ScaledWeights,
    collections.namedtuple(
        "Int4Weights",
        ("weight", "scale", "zero_point", "shape"),
    ),
):
    """Immutable record describing an INT4 quantized weight.

    Fields, in order: ``weight``, ``scale``, ``zero_point`` and ``shape``.
    The ``Int4ScaledWeights`` base supplies the ``__class__``/``grad_fn``
    overrides.
    """
|
|
|
|
|
|
def int4_row_quantize(
    x: torch.Tensor,
    group_size: int = 128,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to signed 4-bit values with one scale/zero point per group.

    The tensor is flattened into rows of ``group_size`` consecutive elements;
    each row gets its own scale and zero point computed from its min and max.

    Args:
        x: Tensor to quantize. Its element count must be a multiple of
            ``group_size`` for the reshape to succeed.
        group_size: Number of consecutive elements that share one scale.

    Returns:
        A ``(quantized, scales, zeros)`` tuple. ``quantized`` is int8 with the
        shape of ``x`` and values in ``[-8, 7]``. ``scales`` and ``zeros`` are
        float tensors laid out as ``(groups_per_row, x.shape[0])``.
    """
    n_bit = 4  # Number of target bits.
    qmin = 0
    qmax = 2**n_bit - 1
    midpoint = 2 ** (n_bit - 1)

    grouped = x.reshape(-1, group_size).to(torch.float)
    group_max = grouped.amax(dim=1, keepdim=True)
    group_min = grouped.amin(dim=1, keepdim=True)

    # The clamp keeps the scale non-zero when a group is constant.
    scales = (group_max - group_min).clamp(min=1e-6) / qmax
    zeros = group_min + scales * midpoint

    levels = ((grouped - group_min) / scales).round().clamp_(qmin, qmax)

    # Recenter output and move to int8.
    quantized = (levels - midpoint).to(dtype=torch.int8).reshape(x.shape)

    # Cutlass expects column major layout for scale and zero point,
    # so we transpose here and make them contiguous.
    rows = x.shape[0]
    return (
        quantized,
        scales.view(rows, -1).t().contiguous(),
        zeros.view(rows, -1).t().contiguous(),
    )
|
|
|
|
|
|
def pack_int4(x: torch.Tensor) -> torch.Tensor:
    """Pack adjacent int4 values (stored one per int8) into single int8 bytes.

    Args:
        x: Tensor indexed as ``x[:, j]``; even columns become the low nibble
            and odd columns the high nibble of each output element.

    Returns:
        A contiguous tensor with half as many columns as ``x``.
    """
    # Low bits need to have sign bits removed.
    low_nibble = x[:, ::2].bitwise_and(0xF)
    # High bits need to left shift, this also masks off extra bits.
    high_nibble = x[:, 1::2].bitwise_left_shift(4)
    # Recombine into a single value with bitwise or.
    return low_nibble.bitwise_or(high_nibble).contiguous()
|
|
|
|
|
|
def bmm_nt(
    x: Tensor,
    w: Fp8RowwiseWeights | Int4Weights,
    num_tokens: Tensor | None = None,
) -> Tensor:
    """Run the batched FBGEMM matmul that matches the quantization of ``w``.

    Args:
        x: Activation tensor handed to the FBGEMM kernels.
        w: Quantized weights; FP8 row-wise or INT4.
        num_tokens: Optional tensor forwarded to the FP8 activation
            quantizer; unused on the INT4 path.

    Returns:
        The result of the selected FBGEMM batched kernel.

    Raises:
        ValueError: If ``w`` is neither an FP8 nor an INT4 weight container.
    """
    if isinstance(w, Fp8ScaledWeights):
        # FP8 path: quantize activations per row, then run the FP8 kernel.
        x_fp8, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
        return torch.ops.fbgemm.f8f8bf16_rowwise_batched(x_fp8, w.weight, x_scale, w.scale)

    if isinstance(w, Int4ScaledWeights):
        # INT4 path: activations are passed through unquantized.
        return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)

    raise ValueError("Unsupported quantization type")
|
|
|
|
|
|
def ffn_swiglu(
    x: Tensor,
    w1: Fp8RowwiseWeights | Int4Weights,
    w3: Fp8RowwiseWeights | Int4Weights,
    w2: Fp8RowwiseWeights | Int4Weights,
    num_tokens: Tensor | None = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """Compute the SwiGLU feed-forward ``(silu(x @ w1.T) * (x @ w3.T)) @ w2.T``.

    When all three weights are FP8 containers, or all three are INT4
    containers, the work is delegated to ``ffn_swiglu_dynamic``. Otherwise the
    weights are treated as plain tensors and multiplied directly.

    Args:
        x: Input of shape ``(B, T, D)`` on the plain-tensor path.
        w1: First projection weight.
        w3: Second projection weight, multiplied with the activated ``w1`` branch.
        w2: Output projection weight.
        num_tokens: Optional tensor forwarded to the quantized path.
        is_memory_bounded: Flag forwarded to the quantized path.

    Returns:
        The feed-forward output tensor.
    """
    weights = (w1, w3, w2)
    all_fp8 = all(isinstance(w, Fp8ScaledWeights) for w in weights)
    all_int4 = all(isinstance(w, Int4ScaledWeights) for w in weights)
    if all_fp8 or all_int4:
        # Int4Weights has no ``activation_scale_ub`` field, so reading it
        # unconditionally raised AttributeError for INT4 weights. The INT4
        # kernel in fc_dynamic ignores the bound, so pass None there.
        activation_scale_ub = w1.activation_scale_ub if all_fp8 else None
        return ffn_swiglu_dynamic(x, w1, w3, w2, activation_scale_ub, num_tokens, is_memory_bounded)

    (B, T, D) = x.shape  # noqa: N806
    (_, D_) = w1.shape  # noqa: N806
    assert D_ == D

    assert isinstance(w1, Tensor)
    assert isinstance(w3, Tensor)
    x1 = x.view(B * T, D) @ w1.T
    x2 = x.view(B * T, D) @ w3.T
    z = torch.nn.functional.silu(x1) * x2
    # Drop the intermediates before the last matmul allocates its output.
    del x1, x2
    assert isinstance(w2, Tensor)
    return (z @ w2.T).view(B, T, D)
|
|
|
|
|
|
@torch.inference_mode()
def quantize_fp8(
    w: Tensor,
    fp8_activation_scale_ub: float,
    output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
    """Quantize [n, k] weight tensor.

    Args:
        w (Tensor): [n, k] input high precision tensor to quantize.
        fp8_activation_scale_ub (float): Upper bound for activation max.
        output_device (torch.device | None): Device on which the
            one-element upper-bound tensor is created. The quantized weight
            and its scale are returned exactly as the FBGEMM op produces them.

    Returns:
        Fp8RowwiseWeights: Quantized weight, row scale, weight shape and the
        upper-bound tensor.
    """
    ub_tensor = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device=output_device,
    )
    quantized, row_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
    # Release the local reference to the high precision input.
    del w
    return Fp8RowwiseWeights(
        weight=quantized,
        scale=row_scale,
        shape=quantized.shape,
        activation_scale_ub=ub_tensor,
    )
|
|
|
|
|
|
@torch.inference_mode()
def quantize_int4(
    w: Tensor,
    output_device: torch.device | None = None,
) -> Int4Weights:
    """Quantize a high precision weight tensor to packed INT4.

    Tensors with three or more dimensions are quantized one slice at a time
    along the first dimension and the per-slice results are stacked back
    together. Packing stores two 4-bit values per int8 element.

    Args:
        w (Tensor): Input high precision tensor to quantize.
        output_device (torch.device | None): Device the resulting weight,
            scale and zero point are moved to.

    Returns:
        Int4Weights: Packed weight, scale, zero point and packed weight shape.
    """
    if w.ndim < 3:
        values, scale, zero_point = int4_row_quantize(w)
        packed = pack_int4(values)
    else:
        per_slice = (int4_row_quantize(piece) for piece in w)
        values, scales, zero_points = zip(*per_slice, strict=False)
        packed = torch.stack([pack_int4(v) for v in values], dim=0)
        scale = torch.stack(scales, dim=0)
        zero_point = torch.stack(zero_points, dim=0)
    # Release the local reference to the high precision input.
    del w
    return Int4Weights(
        weight=packed.to(output_device),
        scale=scale.to(output_device),
        zero_point=zero_point.to(output_device),
        shape=packed.shape,
    )
|
|
|
|
|
|
@torch.inference_mode()
def load_fp8(
    w: Tensor,
    w_scale: Tensor,
    fp8_activation_scale_ub: float,
    output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
    """Load FP8 [n, k] weight tensor.

    Args:
        w (Tensor): [n, k] input FP8.
        w_scale (Tensor): Scale tensor stored alongside ``w``.
        fp8_activation_scale_ub (float): Upper bound for activation max.
        output_device (torch.device | None): Device the weight, scale and
            upper-bound tensor are placed on.

    Returns:
        Fp8RowwiseWeights: The weight cast to ``torch.float8_e4m3fn``, its
        scale, the shape of ``w`` and the upper-bound tensor.
    """
    ub_tensor = torch.tensor(
        [fp8_activation_scale_ub],
        dtype=torch.float,
        device=output_device,
    )
    fp8_weight = w.to(torch.float8_e4m3fn).to(device=output_device)
    placed_scale = w_scale.to(device=output_device)
    return Fp8RowwiseWeights(
        weight=fp8_weight,
        scale=placed_scale,
        shape=w.shape,
        activation_scale_ub=ub_tensor,
    )
|
|
|
|
|
|
@torch.inference_mode()
def load_int4(
    w: Tensor,
    scale: Tensor,
    zero_point: Tensor,
    output_device: torch.device | None = None,
) -> Int4Weights:
    """Load INT4 [n, k/2] weight tensor.

    Args:
        w (Tensor): [n, k/2] input INT4.
        scale (Tensor): Scale tensor stored alongside ``w``.
        zero_point (Tensor): Zero-point tensor stored alongside ``w``.
        output_device (torch.device | None): Device all three tensors are
            moved to.

    Returns:
        Int4Weights: The weight cast to ``torch.int8``, its scale and zero
        point, and the shape of ``w``.
    """
    int8_weight = w.to(torch.int8).to(device=output_device)
    placed_scale = scale.to(device=output_device)
    placed_zero_point = zero_point.to(device=output_device)
    return Int4Weights(
        weight=int8_weight,
        scale=placed_scale,
        zero_point=placed_zero_point,
        shape=w.shape,
    )
|
|
|
|
|
|
def fc_dynamic(
    x: Tensor,
    w: Fp8RowwiseWeights | Int4Weights,
    activation_scale_ub: Tensor | None = None,
    num_tokens: Tensor | None = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """
    Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dynamic row-wise scaling

    Args:
        x: Activation tensor handed to the FBGEMM kernels.
        w: Quantized weights; ``Int4Weights`` selects the INT4 kernel, anything
            else takes the FP8 path.
        activation_scale_ub: Upper bound forwarded to the FP8 activation
            quantizer; unused on the INT4 path.
        num_tokens: Forwarded to the FP8 activation quantizer; unused on the
            INT4 path.
        is_memory_bounded: Accepted for interface compatibility; not read here.

    Returns:
        Output of the selected FBGEMM row-wise kernel.
    """
    if isinstance(w, Int4Weights):
        return torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)

    x_fp8, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
    out = torch.ops.fbgemm.f8f8bf16_rowwise(x_fp8, w.weight, x_scale, w.scale, use_fast_accum=True)
    # Drop the quantized activations as soon as the matmul has consumed them.
    del x_fp8
    return out
|
|
|
|
|
|
def ffn_swiglu_dynamic(
    x: Tensor,
    w1: Fp8RowwiseWeights | Int4Weights,
    w3: Fp8RowwiseWeights | Int4Weights,
    w2: Fp8RowwiseWeights | Int4Weights,
    activation_scale_ub: Tensor | None = None,
    num_tokens: Tensor | None = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """SwiGLU feed-forward built from three ``fc_dynamic`` layers.

    Computes ``fc(silu(fc(x, w1)) * fc(x, w3), w2)`` on the input flattened
    to two dimensions.

    Args:
        x: Input of shape ``(B, T, D)`` or ``(T, D)``.
        w1: First projection weight.
        w3: Second projection weight; its first dimension must equal that of ``w1``.
        w2: Output projection weight.
        activation_scale_ub: Forwarded unchanged to every ``fc_dynamic`` call.
        num_tokens: Forwarded unchanged to every ``fc_dynamic`` call.
        is_memory_bounded: Forwarded unchanged to every ``fc_dynamic`` call.

    Returns:
        For 3-D input, the result viewed as ``(B, T, D)``; for 2-D input, the
        2-D result of the last layer.
    """
    rank = x.dim()
    assert rank in (2, 3)
    if rank == 3:
        (B, T, D) = x.shape  # noqa: N806
    else:
        B = 1  # noqa: N806
        (T, D) = x.shape  # noqa: N806

    assert w1.shape[0] == w3.shape[0]

    flat = x.view(B * T, D)
    extra_args = (activation_scale_ub, num_tokens, is_memory_bounded)
    branch_w1 = fc_dynamic(flat, w1, *extra_args)
    branch_w3 = fc_dynamic(flat, w3, *extra_args)
    hidden = torch.nn.functional.silu(branch_w1) * branch_w3
    # Drop both branch outputs before the output projection runs.
    del branch_w1, branch_w3

    projected = fc_dynamic(hidden, w2, *extra_args)

    return projected.view(B, T, D) if rank == 3 else projected
|