Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 18:22:41 +00:00)
commit 5ecedc12e7
parent c8a0b110c0

    clean up

2 changed files with 2 additions and 18 deletions
@@ -10,7 +10,7 @@ from typing import Callable, Optional
 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import nn, Tensor
+from torch import Tensor, nn
 from torch.nn import functional as F
 
 from ...datatypes import QuantizationMode
 
@@ -65,7 +65,7 @@ class Int4Weights(
     Int4ScaledWeights,
     collections.namedtuple(
         "Int4Weights",
-        ["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
+        ["weight", "scale", "zero_point", "shape"],
     ),
 ):
     pass
@@ -184,20 +184,13 @@ def quantize_fp8(
 @torch.inference_mode()
 def quantize_int4(
     w: Tensor,
-    fp8_activation_scale_ub: float,
     output_device: Optional[torch.device] = None,
 ) -> Int4Weights:
     """Quantize [n, k/2] weight tensor.
 
     Args:
         w (Tensor): [n, k/2] input high precision tensor to quantize.
-        fp8_activation_scale_ub (float): Upper bound for activation max.
     """
-    activation_scale_ub = torch.tensor(
-        [fp8_activation_scale_ub],
-        dtype=torch.float,
-        device=output_device,
-    )
     if w.ndim >= 3:
         wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
         wq = torch.stack([pack_int4(i) for i in wq], dim=0)
@@ -212,7 +205,6 @@ def quantize_int4(
         scale=scale.to(output_device),
         zero_point=zero_point.to(output_device),
         shape=wq.shape,
-        activation_scale_ub=activation_scale_ub,
     )
 
 
@@ -247,26 +239,18 @@ def load_int4(
     w: Tensor,
     scale: Tensor,
     zero_point: Tensor,
-    fp8_activation_scale_ub: float,
     output_device: Optional[torch.device] = None,
 ) -> Int4Weights:
     """Load INT4 [n, k/2] weight tensor.
 
     Args:
         w (Tensor): [n, k/2] input INT4.
-        fp8_activation_scale_ub (float): Upper bound for activation max.
     """
-    activation_scale_ub = torch.tensor(
-        [fp8_activation_scale_ub],
-        dtype=torch.float,
-        device=output_device,
-    )
     return Int4Weights(
         weight=w.to(torch.int8).to(device=output_device),
         scale=scale.to(device=output_device),
         zero_point=zero_point.to(device=output_device),
         shape=w.shape,
-        activation_scale_ub=activation_scale_ub,
     )
 
 
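For reference, a minimal sketch (not part of the commit) of what the slimmed-down Int4Weights container looks like to a caller after this change: the activation_scale_ub field is gone from the namedtuple, so quantize_int4 and load_int4 no longer take an fp8_activation_scale_ub argument and only the packed weight, per-row scale, zero point, and shape are carried. The tensor values below are placeholders, not real quantized data.

    import collections

    import torch

    # Same four-field layout as the namedtuple in the diff above.
    Int4Weights = collections.namedtuple(
        "Int4Weights",
        ["weight", "scale", "zero_point", "shape"],
    )

    n, k = 8, 32
    packed = torch.zeros(n, k // 2, dtype=torch.int8)  # packed [n, k/2] INT4 payload
    weights = Int4Weights(
        weight=packed,
        scale=torch.ones(n, 1),        # per-row scales (placeholder values)
        zero_point=torch.zeros(n, 1),  # per-row zero points (placeholder values)
        shape=packed.shape,
    )
    print(weights.shape)  # torch.Size([8, 16]); no activation_scale_ub field anymore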