diff --git a/llama_toolchain/inference/generation.py b/llama_toolchain/inference/generation.py
index 2411c69f8..6ca882892 100644
--- a/llama_toolchain/inference/generation.py
+++ b/llama_toolchain/inference/generation.py
@@ -102,22 +102,31 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        if (
+        fp8 = (
             config.quantization
             and config.quantization.type == QuantizationType.fp8.value
-        ):
+        )
+
+        if fp8:
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args)
-        model.load_state_dict(state_dict, strict=False)
+
+        if fp8:
+            # load on CPU first since if we are doing fp8, we probably don't
+            # have enough memory on GPU for bf16
+            model.load_state_dict(state_dict, strict=False)
 
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
+        if not fp8:
+            model.load_state_dict(state_dict, strict=False)
+
         if config.quantization:
             from .quantization.loader import convert_to_quantized_model
 
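
The net effect of the hunk above is easier to read as straight-line code. Below is a sketch (not a verbatim excerpt) of how the affected part of the model setup reads after the patch; names such as config, state_dict, model_args, Transformer, and QuantizationType are assumed to come from the surrounding file.

# Sketch of the post-patch load path: decide fp8 once, keep the state_dict
# load on CPU/bf16 for fp8, and defer it until after the default tensor type
# has moved to a CUDA dtype otherwise.
fp8 = (
    config.quantization
    and config.quantization.type == QuantizationType.fp8.value
)

if fp8:
    # load on CPU in bf16 so that fp8 conversion does not find an
    # unexpected (fp32, e.g.) datatype
    torch.set_default_tensor_type(torch.BFloat16Tensor)

model = Transformer(model_args)

if fp8:
    # load on CPU first since, for fp8, there is likely not enough GPU
    # memory to hold the bf16 weights
    model.load_state_dict(state_dict, strict=False)

if torch.cuda.is_bf16_supported():
    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
    torch.set_default_tensor_type(torch.cuda.HalfTensor)

if not fp8:
    # non-fp8 path: load the weights only after the default tensor type has
    # been switched to the CUDA dtype
    model.load_state_dict(state_dict, strict=False)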