From 511b054b7748dfe2794992f0811696017728bbdf Mon Sep 17 00:00:00 2001
From: Yunlu Li
Date: Wed, 20 Nov 2024 21:32:53 -0800
Subject: [PATCH] Fix fp8 quantization script.

---
 .../quantization/{scripts => }/quantize_checkpoint.py | 11 ++++-------
 .../{scripts => }/run_quantize_checkpoint.sh          |  4 ++--
 2 files changed, 6 insertions(+), 9 deletions(-)
 rename llama_stack/providers/inline/inference/meta_reference/quantization/{scripts => }/quantize_checkpoint.py (94%)
 rename llama_stack/providers/inline/inference/meta_reference/quantization/{scripts => }/run_quantize_checkpoint.sh (82%)

diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/quantize_checkpoint.py
similarity index 94%
rename from llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
rename to llama_stack/providers/inline/inference/meta_reference/quantization/quantize_checkpoint.py
index aead05652..bddbc00b7 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/quantize_checkpoint.py
@@ -22,10 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
+from fp8_impls import quantize_fp8
 
-from llama.model import ModelArgs, Transformer, TransformerBlock
-from llama.tokenizer import Tokenizer
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 
 from torch.nn.parameter import Parameter
 
@@ -36,7 +37,6 @@ def main(
     max_seq_len: Optional[int] = 512,
     max_batch_size: Optional[int] = 4,
     model_parallel_size: Optional[int] = None,
-    ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
     seed: int = 1,
 ):
@@ -112,7 +112,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w1.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
@@ -124,7 +123,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w3.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
@@ -136,7 +134,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w2.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh b/llama_stack/providers/inline/inference/meta_reference/quantization/run_quantize_checkpoint.sh
similarity index 82%
rename from llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh
rename to llama_stack/providers/inline/inference/meta_reference/quantization/run_quantize_checkpoint.sh
index 9282bce2a..ef1a8ccc9 100755
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/run_quantize_checkpoint.sh
@@ -9,7 +9,7 @@ set -euo pipefail
 
 set -x
 
-cd $(git rev-parse --show-toplevel)
+cd $(dirname "$(realpath "$0")")
 
 MASTER_HOST=$1
 RUN_ID=$2
@@ -21,7 +21,7 @@ NPROC=$7
 
 echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
 
-NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
+NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models" \
    torchrun \
    --nnodes=$NNODES --nproc_per_node=$NPROC \
    --rdzv_id=$RUN_ID \
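
Note: after this change the script resolves paths relative to its own location and expects a
llama-models checkout at /home/$USER/llama-models (hardcoded into PYTHONPATH above). A sketch
of a single-node invocation follows. Only $1 (MASTER_HOST), $2 (RUN_ID), $7 (NPROC), and the
CKPT_DIR/QUANT_CKPT_DIR variables echoed above are visible in this diff; the meanings of $5
(assumed here to be a tokenizer path) and $6 (assumed to be NNODES, matching the torchrun
--nnodes flag) are inferred, and all example paths are placeholders:

    # Hypothetical run; positional-argument meanings inferred from this diff,
    # with $5 (tokenizer path) and $6 (NNODES) labeled as assumptions.
    ./run_quantize_checkpoint.sh localhost fp8-quant-001 \
        ~/checkpoints/Meta-Llama-3.1-8B \
        ~/checkpoints/Meta-Llama-3.1-8B-fp8 \
        ~/checkpoints/Meta-Llama-3.1-8B/tokenizer.model \
        1 8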