mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
clean up
This commit is contained in:
parent
c8a0b110c0
commit
5ecedc12e7
2 changed files with 2 additions and 18 deletions
|
@ -10,7 +10,7 @@ from typing import Callable, Optional
|
|||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from torch import nn, Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...datatypes import QuantizationMode
|
||||
|
|
|
@ -65,7 +65,7 @@ class Int4Weights(
|
|||
Int4ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Int4Weights",
|
||||
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
|
||||
["weight", "scale", "zero_point", "shape"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
@ -184,20 +184,13 @@ def quantize_fp8(
|
|||
@torch.inference_mode()
|
||||
def quantize_int4(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Quantize [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
if w.ndim >= 3:
|
||||
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
|
||||
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
|
||||
|
@ -212,7 +205,6 @@ def quantize_int4(
|
|||
scale=scale.to(output_device),
|
||||
zero_point=zero_point.to(output_device),
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
|
@ -247,26 +239,18 @@ def load_int4(
|
|||
w: Tensor,
|
||||
scale: Tensor,
|
||||
zero_point: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Load INT4 [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input INT4.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
return Int4Weights(
|
||||
weight=w.to(torch.int8).to(device=output_device),
|
||||
scale=scale.to(device=output_device),
|
||||
zero_point=zero_point.to(device=output_device),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue