mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Reduce loading time for non-fp8

This commit is contained in:
parent fef679bb34
commit 0e2fc9966a

1 changed file with 12 additions and 3 deletions
@@ -102,15 +102,21 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        if (
+        fp8 = (
             config.quantization
             and config.quantization.type == QuantizationType.fp8.value
-        ):
+        )
+
+        if fp8:
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args)
-        model.load_state_dict(state_dict, strict=False)
+
+        if fp8:
+            # load on CPU first since if we are doing fp8, we probably don't
+            # have enough memory on GPU for bf16
+            model.load_state_dict(state_dict, strict=False)
 
         if torch.cuda.is_bf16_supported():
@@ -118,6 +124,9 @@ class Llama:
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
+        if not fp8:
+            model.load_state_dict(state_dict, strict=False)
+
         if config.quantization:
             from .quantization.loader import convert_to_quantized_model
 
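In short, the patch computes the fp8 check once, keeps the early CPU-side weight load only for the fp8 path, and defers the non-fp8 load until after the CUDA default dtype has been chosen. Below is a minimal sketch of the resulting load order; the names config, model_args, state_dict, Transformer, and QuantizationType are assumed to come from the surrounding Llama class in the real code, so this is an illustration of the flow, not the full method.

import torch

# Hypothetical harness: config, model_args, state_dict, Transformer, and
# QuantizationType are provided by the surrounding Llama class in the
# actual repository.
fp8 = (
    config.quantization
    and config.quantization.type == QuantizationType.fp8.value
)

if fp8:
    # fp8 path: build and fill the model on CPU in bf16 so the later fp8
    # conversion does not encounter an unexpected dtype
    torch.set_default_tensor_type(torch.BFloat16Tensor)

model = Transformer(model_args)

if fp8:
    # fp8 path: load on CPU first; the GPU likely cannot hold a bf16 copy
    model.load_state_dict(state_dict, strict=False)

# choose the CUDA default dtype before the non-fp8 load
if torch.cuda.is_bf16_supported():
    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
    torch.set_default_tensor_type(torch.cuda.HalfTensor)

if not fp8:
    # non-fp8 path: a single load, deferred until after the CUDA default
    # dtype is set
    model.load_state_dict(state_dict, strict=False)

Per the commit title, the reordering is what saves time: non-fp8 models no longer pay for the unconditional early load, and instead do one load_state_dict once the CUDA default dtype is in place.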