forked from phoenix-oss/llama-stack-mirror
Fix fp8 quantization script. (#500)
# What does this PR do?

Fix fp8 quantization script.

## Test Plan

```
sh run_quantize_checkpoint.sh localhost fp8 /home/yll/fp8_test/ /home/yll/fp8_test/quantized_2 /home/yll/fp8_test/tokenizer.model 1 1
```

## Sources

Please link relevant resources if necessary.

## Before submitting

- [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [x] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

Co-authored-by: Yunlu Li <yll@meta.com>
parent cf079a22a0
commit 4e1105e563

2 changed files with 9 additions and 9 deletions
The first changed file is the Python quantization script:

```diff
@@ -22,12 +22,16 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
 
-from llama.model import ModelArgs, Transformer, TransformerBlock
-from llama.tokenizer import Tokenizer
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from torch.nn.parameter import Parameter
+
+from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
+    quantize_fp8,
+)
 
 
 def main(
     ckpt_dir: str,
```
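The import migration replaces the standalone `fp8` and `llama` modules with their packaged equivalents: model and tokenizer classes now come from `llama_models`, and `quantize_fp8` from `llama_stack`'s meta-reference provider. Note that `FfnQuantizeMode` is no longer imported at all. A quick way to sanity-check the new layout, assuming both packages are importable (e.g. via the `PYTHONPATH` the run script sets below), is:

```python
# Import smoke test for the new module layout; assumes llama-models and
# llama-stack are on PYTHONPATH (the run script below adds both checkouts).
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock

from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
    quantize_fp8,
)

print("imports resolve:", ModelArgs, Tokenizer, Transformer, TransformerBlock, quantize_fp8)
```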
```diff
@@ -36,7 +40,6 @@ def main(
     max_seq_len: Optional[int] = 512,
     max_batch_size: Optional[int] = 4,
     model_parallel_size: Optional[int] = None,
-    ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
     seed: int = 1,
 ):
```
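With `FfnQuantizeMode` gone from the imports, the corresponding keyword parameter is dropped from `main` as well. Reconstructed from this hunk, the resulting signature looks roughly like the sketch below; the two parameters between `ckpt_dir` and `max_seq_len` fall outside the hunk's context lines, so their names here are assumptions (modeled on the run script's checkpoint and tokenizer arguments):

```python
from typing import Optional

def main(
    ckpt_dir: str,
    quantized_ckpt_dir: str,  # assumed name: not visible in the hunk
    tokenizer_path: str,      # assumed name: not visible in the hunk
    max_seq_len: Optional[int] = 512,
    max_batch_size: Optional[int] = 4,
    model_parallel_size: Optional[int] = None,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
    seed: int = 1,
):
    ...
```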
```diff
@@ -112,7 +115,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w1.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
```
```diff
@@ -124,7 +126,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w3.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
```
```diff
@@ -136,7 +137,6 @@ def main(
             fp8_weight = quantize_fp8(
                 block.feed_forward.w2.weight,
                 fp8_activation_scale_ub,
-                ffn_quantize_mode,
                 output_device=torch.device("cpu"),
             )
             with torch.inference_mode():
```
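The same one-line fix is applied to all three FFN projections (`w1`, `w3`, `w2`): each call drops the `ffn_quantize_mode` positional argument to match the `llama_stack` `quantize_fp8` signature. Judging by the deleted default (`FfnQuantizeMode.FP8_ROWWISE`), row-wise quantization is the behavior being kept. As an illustration of the technique only — a minimal sketch, not the `fp8_impls` implementation — row-wise fp8 (e4m3) weight quantization amounts to one scale per output row:

```python
import torch

def rowwise_fp8_sketch(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative row-wise fp8 quantization (requires PyTorch >= 2.1 for float8)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3
    # One scale per output row, so each row's max magnitude maps to fp8_max.
    row_amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = row_amax / fp8_max
    w_fp8 = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8, scale  # dequantize as w_fp8.float() * scale

w = torch.randn(4096, 4096)
w_fp8, scale = rowwise_fp8_sketch(w)
print(w_fp8.dtype, scale.shape)  # torch.float8_e4m3fn, torch.Size([4096, 1])
```

The real call additionally threads through `fp8_activation_scale_ub` (by its name, an upper bound on the activation scale applied at inference time) and an `output_device`, as seen in the hunks above.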
The second changed file is run_quantize_checkpoint.sh:

```diff
@@ -9,7 +9,7 @@
 set -euo pipefail
 set -x
 
-cd $(git rev-parse --show-toplevel)
+cd $(dirname "$(realpath "$0")")
 
 MASTER_HOST=$1
 RUN_ID=$2
```
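Design note on the `cd` change: `git rev-parse --show-toplevel` fails outside a git checkout and lands at the repository root, whereas `dirname "$(realpath "$0")"` resolves (following symlinks) to the directory containing the script itself, so the launcher now works from any working directory and does not depend on the repository layout.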
```diff
@@ -21,7 +21,7 @@ NPROC=$7
 
 echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
 
-NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
+NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
 torchrun \
   --nnodes=$NNODES --nproc_per_node=$NPROC \
   --rdzv_id=$RUN_ID \
```
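Two notes on the launcher change. First, the added `PYTHONPATH` entry points at source checkouts of `llama-models` and `llama-stack` under `/home/$USER/`, which is what lets the reorganized `llama_models.*` and `llama_stack.*` imports in the Python script resolve without installing either package. Second, it clarifies the test-plan invocation: the seven positional arguments are `MASTER_HOST=localhost`, `RUN_ID=fp8`, the checkpoint directory, the quantized-checkpoint output directory, what appears to be the tokenizer path, `NNODES=1`, and `NPROC=1` (the names of `$3`-`$5` are inferred from the `echo` line and the test command; their assignments are not shown in these hunks).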