From 4e1105e563a14ed1aba99c031d681a4f2b8a4d2e Mon Sep 17 00:00:00 2001
From: liyunlu0618 <9705880+liyunlu0618@users.noreply.github.com>
Date: Thu, 21 Nov 2024 09:15:28 -0800
Subject: [PATCH] Fix fp8 quantization script. (#500)

# What does this PR do?

Fix fp8 quantization script.

## Test Plan

```
sh run_quantize_checkpoint.sh localhost fp8 /home/yll/fp8_test/ /home/yll/fp8_test/quantized_2 /home/yll/fp8_test/tokenizer.model 1 1
```

## Sources

Please link relevant resources if necessary.

## Before submitting

- [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [x] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

Co-authored-by: Yunlu Li
---
 .../quantization/scripts/quantize_checkpoint.py | 14 +++++++-------
 .../scripts/run_quantize_checkpoint.sh          |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
index aead05652..891a06296 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
@@ -22,12 +22,16 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
-from llama.model import ModelArgs, Transformer, TransformerBlock
-from llama.tokenizer import Tokenizer
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from torch.nn.parameter import Parameter
 
+from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
+    quantize_fp8,
+)
+
 
 def main(
     ckpt_dir: str,
@@ -36,7 +40,6 @@ def main(
     max_seq_len: Optional[int] = 512,
     max_batch_size: Optional[int] = 4,
     model_parallel_size: Optional[int] = None,
-    ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
     seed: int = 1,
 ):
@@ -112,7 +115,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w1.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
@@ -124,7 +126,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w3.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
@@ -136,7 +137,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w2.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh
index 9282bce2a..84f41d414 100755
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh
@@ -9,7 +9,7 @@ set -euo pipefail
 
 set -x
 
-cd $(git rev-parse --show-toplevel)
+cd $(dirname "$(realpath "$0")")
 
 MASTER_HOST=$1
 RUN_ID=$2
@@ -21,7 +21,7 @@ NPROC=$7
 
 echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
 
-NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
+NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
 torchrun \
   --nnodes=$NNODES --nproc_per_node=$NPROC \
   --rdzv_id=$RUN_ID \
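Note (not part of the patch): the sketch below only illustrates the call pattern the patched script uses, namely `quantize_fp8` imported from `llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls` and called without the removed `ffn_quantize_mode` argument. The helper name `quantize_ffn_weight` and the standalone-function shape are illustrative assumptions, not code from the repository.

```
# A minimal sketch, assuming `block` is a TransformerBlock from a loaded
# reference-impl Transformer. `quantize_ffn_weight` is a hypothetical helper.
import torch

from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
    quantize_fp8,
)


def quantize_ffn_weight(block, fp8_activation_scale_ub: float = 1200.0):
    # After this patch, quantize_fp8 takes the weight tensor, the activation
    # scale upper bound, and an output device; no ffn_quantize_mode argument.
    return quantize_fp8(
        block.feed_forward.w1.weight,
        fp8_activation_scale_ub,
        output_device=torch.device("cpu"),
    )
```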