diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py index aead05652..891a06296 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py @@ -22,12 +22,16 @@ from fairscale.nn.model_parallel.initialize import ( initialize_model_parallel, model_parallel_is_initialized, ) -from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8 -from llama.model import ModelArgs, Transformer, TransformerBlock -from llama.tokenizer import Tokenizer +from llama_models.llama3.api.args import ModelArgs +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from torch.nn.parameter import Parameter +from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import ( + quantize_fp8, +) + def main( ckpt_dir: str, @@ -36,7 +40,6 @@ def main( max_seq_len: Optional[int] = 512, max_batch_size: Optional[int] = 4, model_parallel_size: Optional[int] = None, - ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE, fp8_activation_scale_ub: Optional[float] = 1200.0, seed: int = 1, ): @@ -112,7 +115,6 @@ def main( fp8_weight = quantize_fp8( block.feed_forward.w1.weight, fp8_activation_scale_ub, - ffn_quantize_mode, output_device=torch.device("cpu"), ) with torch.inference_mode(): @@ -124,7 +126,6 @@ def main( fp8_weight = quantize_fp8( block.feed_forward.w3.weight, fp8_activation_scale_ub, - ffn_quantize_mode, output_device=torch.device("cpu"), ) with torch.inference_mode(): @@ -136,7 +137,6 @@ def main( fp8_weight = quantize_fp8( block.feed_forward.w2.weight, fp8_activation_scale_ub, - ffn_quantize_mode, output_device=torch.device("cpu"), ) with torch.inference_mode(): diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh index 9282bce2a..84f41d414 100755 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh @@ -9,7 +9,7 @@ set -euo pipefail set -x -cd $(git rev-parse --show-toplevel) +cd $(dirname "$(realpath "$0")") MASTER_HOST=$1 RUN_ID=$2 @@ -21,7 +21,7 @@ NPROC=$7 echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR -NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \ +NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \ torchrun \ --nnodes=$NNODES --nproc_per_node=$NPROC \ --rdzv_id=$RUN_ID \