From 0e2fc9966a6e2e54f7fe5989c07bd4bb6d5f3786 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 22 Jul 2024 19:21:04 -0700
Subject: [PATCH] Reduce loading time for non-fp8

---
 llama_toolchain/inference/generation.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/llama_toolchain/inference/generation.py b/llama_toolchain/inference/generation.py
index 2411c69f8..6ca882892 100644
--- a/llama_toolchain/inference/generation.py
+++ b/llama_toolchain/inference/generation.py
@@ -102,22 +102,31 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        if (
+        fp8 = (
             config.quantization
             and config.quantization.type == QuantizationType.fp8.value
-        ):
+        )
+
+        if fp8:
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args)
-        model.load_state_dict(state_dict, strict=False)
+
+        if fp8:
+            # load on CPU first since if we are doing fp8, we probably don't
+            # have enough memory on GPU for bf16
+            model.load_state_dict(state_dict, strict=False)
 
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
+        if not fp8:
+            model.load_state_dict(state_dict, strict=False)
+
         if config.quantization:
             from .quantization.loader import convert_to_quantized_model
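
For readers skimming the patch, the sketch below condenses the load path as it looks after this change. It is not a drop-in excerpt of generation.py: `config`, `model_args`, `state_dict`, `Transformer`, and `QuantizationType` are passed in as parameters here purely for illustration, whereas in the real code (see the hunk above) they come from the surrounding Llama build logic.

    import torch

    def build_model_sketch(config, model_args, state_dict, Transformer, QuantizationType):
        """Illustrative sketch of the post-patch load ordering; the actual code
        lives in llama_toolchain/inference/generation.py (see hunk above)."""
        fp8 = (
            config.quantization
            and config.quantization.type == QuantizationType.fp8.value
        )

        if fp8:
            # fp8 path: construct the model on CPU in bf16 so the later fp8
            # conversion does not encounter an unexpected dtype.
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Transformer(model_args)

        if fp8:
            # Load the checkpoint on CPU first; the GPU likely cannot hold the
            # full bf16 weights before they are converted to fp8.
            model.load_state_dict(state_dict, strict=False)

        # Switch the default tensor type to a CUDA type for subsequent allocations.
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)

        if not fp8:
            # non-fp8 path: the state dict is now loaded only after the switch
            # to a CUDA default tensor type; this reordering is what the patch
            # subject ("Reduce loading time for non-fp8") refers to.
            model.load_state_dict(state_dict, strict=False)

        return model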