Reduce loading time for non-fp8

Ashwin Bharambe 2024-07-22 19:21:04 -07:00
parent fef679bb34
commit 0e2fc9966a

@@ -102,22 +102,31 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        if (
+        fp8 = (
             config.quantization
             and config.quantization.type == QuantizationType.fp8.value
-        ):
+        )
+
+        if fp8:
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args)
-        model.load_state_dict(state_dict, strict=False)
+
+        if fp8:
+            # load on CPU first since if we are doing fp8, we probably don't
+            # have enough memory on GPU for bf16
+            model.load_state_dict(state_dict, strict=False)
 
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
+        if not fp8:
+            model.load_state_dict(state_dict, strict=False)
+
         if config.quantization:
             from .quantization.loader import convert_to_quantized_model
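
The diff reorders calls to torch.set_default_tensor_type, which controls both the dtype and the device of newly created tensors: a CPU bf16 default keeps weights on the host for the fp8 conversion path, while a CUDA default makes new tensors land directly on the GPU. The sketch below only demonstrates that PyTorch behavior in isolation; TinyModel is a hypothetical stand-in for the repo's Transformer and is not code from this commit.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical stand-in for the repo's Transformer."""

    def __init__(self) -> None:
        super().__init__()
        # Parameters pick up the default tensor type (dtype + device)
        # that is in effect when the module is constructed.
        self.proj = nn.Linear(16, 16, bias=False)

# CPU bf16 default: weights are created on the host, which is what the
# fp8 path wants before quantization.
torch.set_default_tensor_type(torch.BFloat16Tensor)
cpu_model = TinyModel()
print(cpu_model.proj.weight.dtype, cpu_model.proj.weight.device)
# -> torch.bfloat16 cpu

if torch.cuda.is_available():
    # CUDA default (bf16 if the GPU supports it, fp16 otherwise): new
    # tensors are created directly on the GPU, so freshly built
    # parameters need no separate host-to-device copy.
    if torch.cuda.is_bf16_supported():
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
    else:
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
    gpu_model = TinyModel()
    print(gpu_model.proj.weight.dtype, gpu_model.proj.weight.device)
    # -> torch.bfloat16 cuda:0 (or torch.float16 cuda:0)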