mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 18:32:40 +00:00
Fix fp8 quantization script.
This commit is contained in:
parent
e84d4436b5
commit
511b054b77
2 changed files with 6 additions and 9 deletions
|
|
@ -22,10 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
|
||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
model_parallel_is_initialized,
|
model_parallel_is_initialized,
|
||||||
)
|
)
|
||||||
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
|
from fp8_impls import quantize_fp8
|
||||||
|
|
||||||
from llama.model import ModelArgs, Transformer, TransformerBlock
|
from llama_models.llama3.api.args import ModelArgs
|
||||||
from llama.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -36,7 +37,6 @@ def main(
|
||||||
max_seq_len: Optional[int] = 512,
|
max_seq_len: Optional[int] = 512,
|
||||||
max_batch_size: Optional[int] = 4,
|
max_batch_size: Optional[int] = 4,
|
||||||
model_parallel_size: Optional[int] = None,
|
model_parallel_size: Optional[int] = None,
|
||||||
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
|
|
||||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||||
seed: int = 1,
|
seed: int = 1,
|
||||||
):
|
):
|
||||||
|
|
@ -112,7 +112,6 @@ def main(
|
||||||
fp8_weight = quantize_fp8(
|
fp8_weight = quantize_fp8(
|
||||||
block.feed_forward.w1.weight,
|
block.feed_forward.w1.weight,
|
||||||
fp8_activation_scale_ub,
|
fp8_activation_scale_ub,
|
||||||
ffn_quantize_mode,
|
|
||||||
output_device=torch.device("cpu"),
|
output_device=torch.device("cpu"),
|
||||||
)
|
)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
|
@ -124,7 +123,6 @@ def main(
|
||||||
fp8_weight = quantize_fp8(
|
fp8_weight = quantize_fp8(
|
||||||
block.feed_forward.w3.weight,
|
block.feed_forward.w3.weight,
|
||||||
fp8_activation_scale_ub,
|
fp8_activation_scale_ub,
|
||||||
ffn_quantize_mode,
|
|
||||||
output_device=torch.device("cpu"),
|
output_device=torch.device("cpu"),
|
||||||
)
|
)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
|
@ -136,7 +134,6 @@ def main(
|
||||||
fp8_weight = quantize_fp8(
|
fp8_weight = quantize_fp8(
|
||||||
block.feed_forward.w2.weight,
|
block.feed_forward.w2.weight,
|
||||||
fp8_activation_scale_ub,
|
fp8_activation_scale_ub,
|
||||||
ffn_quantize_mode,
|
|
||||||
output_device=torch.device("cpu"),
|
output_device=torch.device("cpu"),
|
||||||
)
|
)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
|
@ -9,7 +9,7 @@
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
set -x
|
set -x
|
||||||
|
|
||||||
cd $(git rev-parse --show-toplevel)
|
cd $(dirname "$(realpath "$0")")
|
||||||
|
|
||||||
MASTER_HOST=$1
|
MASTER_HOST=$1
|
||||||
RUN_ID=$2
|
RUN_ID=$2
|
||||||
|
|
@ -21,7 +21,7 @@ NPROC=$7
|
||||||
|
|
||||||
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
||||||
|
|
||||||
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
|
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models" \
|
||||||
torchrun \
|
torchrun \
|
||||||
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
||||||
--rdzv_id=$RUN_ID \
|
--rdzv_id=$RUN_ID \
|
||||||
Loading…
Add table
Add a link
Reference in a new issue