Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-16 22:09:27 +00:00
Fix fp8 quantization script.
parent e84d4436b5, commit 511b054b77
2 changed files with 6 additions and 9 deletions
@@ -22,10 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
+from fp8_impls import quantize_fp8
 
-from llama.model import ModelArgs, Transformer, TransformerBlock
-from llama.tokenizer import Tokenizer
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from torch.nn.parameter import Parameter
 
 
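Two separate fixes land in this import block: the FP8 helpers are imported as a top-level fp8_impls module rather than through the old fp8.fp8_impls package path, and the model, tokenizer, and argument classes now come from the external llama_models package instead of the in-repo llama module. FfnQuantizeMode is dropped entirely; the next hunk removes its last use. A minimal sketch of why the bare fp8_impls import resolves, assuming the script is executed directly from its own directory (which the shell hunk further down guarantees); the sibling file name fp8_impls.py is inferred from the import, nothing more:

# Sketch: when Python executes a script directly, it prepends the script's
# directory to sys.path, so a sibling fp8_impls.py becomes importable
# without the old "fp8." package prefix.
import sys

print(sys.path[0])  # directory of the executed script
from fp8_impls import quantize_fp8  # resolves if fp8_impls.py sits next to it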
@@ -36,7 +37,6 @@ def main(
     max_seq_len: Optional[int] = 512,
     max_batch_size: Optional[int] = 4,
     model_parallel_size: Optional[int] = None,
-    ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
     seed: int = 1,
 ):
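With the mode enum gone, row-wise FP8 is the script's only quantization scheme, so main() drops the dead ffn_quantize_mode parameter rather than carrying an unused default. A hedged sketch of the resulting signature; the leading checkpoint-directory parameters are an assumption based on the CKPT_DIR and QUANT_CKPT_DIR variables in the launcher script below, not something this hunk shows:

from typing import Optional

def main(
    ckpt_dir: str,             # assumed from the launcher's CKPT_DIR
    quantized_ckpt_dir: str,   # assumed from the launcher's QUANT_CKPT_DIR
    max_seq_len: Optional[int] = 512,
    max_batch_size: Optional[int] = 4,
    model_parallel_size: Optional[int] = None,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
    seed: int = 1,
):
    ...  # load checkpoint, quantize w1/w3/w2, write the quantized checkpoint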
@@ -112,7 +112,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w1.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
@@ -124,7 +123,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w3.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
@@ -136,7 +134,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w2.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
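The three hunks above apply the same one-line fix to each feed-forward projection (w1, w3, then w2): the ffn_quantize_mode positional argument is deleted so the call sites match the new quantize_fp8 signature. A sketch of the resulting per-block loop, assuming quantize_fp8 returns a wrapper whose .weight field holds the packed FP8 tensor and that the script swaps it into the module under inference mode; neither detail is visible in this diff:

import torch
from torch.nn.parameter import Parameter

from fp8_impls import quantize_fp8
from llama_models.llama3.reference_impl.model import TransformerBlock

def quantize_ffn_weights(model, fp8_activation_scale_ub):
    # Post-change loop over transformer blocks; `model` is the loaded Transformer.
    for block in model.layers:
        if not isinstance(block, TransformerBlock):
            continue
        for proj_name in ("w1", "w3", "w2"):
            proj = getattr(block.feed_forward, proj_name)
            fp8_weight = quantize_fp8(
                proj.weight,
                fp8_activation_scale_ub,
                output_device=torch.device("cpu"),  # stage packed weights on CPU
            )
            with torch.inference_mode():
                # assumed: swap the bf16 weight for the packed FP8 tensor
                proj.weight = Parameter(fp8_weight.weight, requires_grad=False)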
@@ -9,7 +9,7 @@
 set -euo pipefail
 set -x
 
-cd $(git rev-parse --show-toplevel)
+cd $(dirname "$(realpath "$0")")
 
 MASTER_HOST=$1
 RUN_ID=$2
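Swapping git rev-parse --show-toplevel for dirname "$(realpath "$0")" removes two failure modes: the old command aborts when the scripts run outside a git checkout (for example, copied onto a cluster node), and it lands at the repository root rather than next to the script and its fp8_impls helper. A small bash sketch of the new behavior, assuming the script is invoked by path:

# Old: fatal outside a work tree:
#   fatal: not a git repository (or any of the parent directories): .git
# New: always cd to the directory containing this script, resolving symlinks.
cd "$(dirname "$(realpath "$0")")"
pwd  # prints the script's own directory, wherever it was launched from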
@@ -21,7 +21,7 @@ NPROC=$7
 
 echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
 
-NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
+NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models" \
 torchrun \
   --nnodes=$NNODES --nproc_per_node=$NPROC \
   --rdzv_id=$RUN_ID \
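The added PYTHONPATH entry is what lets the llama_models.* imports from the first hunk resolve when running against a source checkout of llama-models rather than an installed package; torchrun propagates the caller's environment to every worker, so setting it once on the launch line covers all ranks. A quick hedged sanity check before launching the full distributed job:

# Confirm the checkout is importable with the same PYTHONPATH the launcher uses.
PYTHONPATH="/home/$USER/llama-models" python -c \
  "from llama_models.llama3.api.args import ModelArgs; print('imports ok')"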