cleanup for fp8 and requirements etc

Ashwin Bharambe 2024-07-20 23:21:41 -07:00
parent 2428701951
commit d73fed5cc3
9 changed files with 78 additions and 905 deletions


@@ -5,7 +5,6 @@ import os
 from typing import Optional
 import torch
-from torch import Tensor
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from models.llama3_1.api.model import Transformer, TransformerBlock
@@ -17,6 +16,7 @@ from toolchain.inference.api.config import (
     InlineImplConfig,
 )
 from toolchain.inference.api.datatypes import QuantizationType
+from torch import Tensor
 def is_fbgemm_available() -> bool:
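Only the signature of `is_fbgemm_available` appears in this hunk; its body lives elsewhere in the file. A minimal sketch of such a guard, assuming it probes for the `fbgemm_gpu` package that supplies the fp8 kernels used below:

```python
import importlib.util


def is_fbgemm_available() -> bool:
    # Probe for the fbgemm_gpu extension without importing it eagerly;
    # the fp8 load/quantize paths below require its CUDA kernels.
    return importlib.util.find_spec("fbgemm_gpu") is not None
```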
@@ -69,27 +69,15 @@ def convert_to_quantized_model(
                 continue
             block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
-            block.feed_forward.w1.weight = load_fp8(
-                block.feed_forward.w1.weight,
-                fp8_scales[
-                    f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
-                ],
-                fp8_activation_scale_ub,
-            )
-            block.feed_forward.w3.weight = load_fp8(
-                block.feed_forward.w3.weight,
-                fp8_scales[
-                    f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
-                ],
-                fp8_activation_scale_ub,
-            )
-            block.feed_forward.w2.weight = load_fp8(
-                block.feed_forward.w2.weight,
-                fp8_scales[
-                    f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
-                ],
-                fp8_activation_scale_ub,
-            )
+            for key in ("w1", "w3", "w2"):
+                param = getattr(block.feed_forward, key)
+                param.weight = load_fp8(
+                    param.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
     else:
         cprint("Quantizing fp8 weights from bf16...", "yellow")
         for block in model.layers:
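Both branches monkey-patch the module's `forward` with `swiglu_wrapper.__get__(block.feed_forward)`. Plain Python functions are descriptors, so calling `__get__` with an instance returns a bound method, and the wrapper then receives the FeedForward module as `self`. A self-contained sketch of that binding mechanism (the wrapper body here is a stand-in, not the repo's fp8 SwiGLU):

```python
class FeedForward:
    def forward(self, x):
        return f"bf16 forward({x})"


def swiglu_wrapper(self, x):
    # Stand-in body; the real wrapper routes through fp8 kernels using
    # the quantized weights stored on `self`.
    return f"fp8 forward({x}) on {type(self).__name__}"


ffn = FeedForward()
# __get__ binds the free function to this instance, so the patched
# forward receives `ffn` as self, exactly like a normal method call.
ffn.forward = swiglu_wrapper.__get__(ffn)
print(ffn.forward("tokens"))  # fp8 forward(tokens) on FeedForward
```

Patching per instance keeps the change scoped: only blocks whose weights were converted take the fp8 path (the first and last layers are skipped above), while everything else keeps the class-level forward.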
@@ -97,21 +85,13 @@ def convert_to_quantized_model(
             if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                 continue
             block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
-            block.feed_forward.w1.weight = quantize_fp8(
-                block.feed_forward.w1.weight,
-                fp8_activation_scale_ub,
-                output_device=torch.device("cuda"),
-            )
-            block.feed_forward.w3.weight = quantize_fp8(
-                block.feed_forward.w3.weight,
-                fp8_activation_scale_ub,
-                output_device=torch.device("cuda"),
-            )
-            block.feed_forward.w2.weight = quantize_fp8(
-                block.feed_forward.w2.weight,
-                fp8_activation_scale_ub,
-                output_device=torch.device("cuda"),
-            )
+            for key in ("w1", "w3", "w2"):
+                param = getattr(block.feed_forward, key)
+                param.weight = quantize_fp8(
+                    param.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
     for _, parameter in model.named_parameters():
         if not isinstance(parameter, Fp8ScaledWeights):
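`quantize_fp8` and `load_fp8` are defined in the fp8 helper module and not shown in this diff. As a rough sketch of the row-wise fp8 weight quantization such helpers typically implement (an illustrative assumption, not the repo's code; handling of `fp8_activation_scale_ub`, which bounds the activation scale, is omitted):

```python
import torch


def rowwise_fp8_quantize(w: torch.Tensor, output_device=None):
    """Quantize a 2-D weight to fp8e4m3 with one scale per output row."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Map each row's max magnitude onto the fp8 dynamic range; clamp the
    # scale away from zero so all-zero rows don't divide by zero.
    scale = w.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / fp8_max
    q = (w.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    if output_device is not None:
        q, scale = q.to(output_device), scale.to(output_device)
    # The kernel later rescales its output, e.g. y = (q @ x) * scale.
    return q, scale
```

The fp8 tensor is only meaningful together with its scale, which is presumably why converted weights are wrapped in `Fp8ScaledWeights` rather than left as plain tensors for the `isinstance` check above.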