mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
cleanup for fp8 and requirements etc
This commit is contained in:
parent
2428701951
commit
d73fed5cc3
9 changed files with 78 additions and 905 deletions
|
@ -5,7 +5,6 @@ import os
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from models.llama3_1.api.model import Transformer, TransformerBlock
|
||||
|
@ -17,6 +16,7 @@ from toolchain.inference.api.config import (
|
|||
InlineImplConfig,
|
||||
)
|
||||
from toolchain.inference.api.datatypes import QuantizationType
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def is_fbgemm_available() -> bool:
|
||||
|
@ -69,27 +69,15 @@ def convert_to_quantized_model(
|
|||
continue
|
||||
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
block.feed_forward.w1.weight = load_fp8(
|
||||
block.feed_forward.w1.weight,
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
|
||||
],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
block.feed_forward.w3.weight = load_fp8(
|
||||
block.feed_forward.w3.weight,
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
|
||||
],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
block.feed_forward.w2.weight = load_fp8(
|
||||
block.feed_forward.w2.weight,
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
|
||||
],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = load_fp8(
|
||||
param.weight,
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
|
||||
],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
cprint("Quantizing fp8 weights from bf16...", "yellow")
|
||||
for block in model.layers:
|
||||
|
@ -97,21 +85,13 @@ def convert_to_quantized_model(
|
|||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
block.feed_forward.w1.weight = quantize_fp8(
|
||||
block.feed_forward.w1.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
block.feed_forward.w3.weight = quantize_fp8(
|
||||
block.feed_forward.w3.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
block.feed_forward.w2.weight = quantize_fp8(
|
||||
block.feed_forward.w2.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = quantize_fp8(
|
||||
param.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue