mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Reduce loading time for non-fp8
parent fef679bb34
commit 0e2fc9966a
1 changed file with 12 additions and 3 deletions
@@ -102,15 +102,21 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

-        if (
+        fp8 = (
             config.quantization
             and config.quantization.type == QuantizationType.fp8.value
-        ):
+        )
+
+        if fp8:
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)

         model = Transformer(model_args)
-        model.load_state_dict(state_dict, strict=False)
+
+        if fp8:
+            # load on CPU first since if we are doing fp8, we probably don't
+            # have enough memory on GPU for bf16
+            model.load_state_dict(state_dict, strict=False)

         if torch.cuda.is_bf16_supported():
@@ -118,6 +124,9 @@ class Llama:
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)

+        if not fp8:
+            model.load_state_dict(state_dict, strict=False)
+
         if config.quantization:
             from .quantization.loader import convert_to_quantized_model
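The change is easiest to read as a single fp8 flag choosing between two load orders: the fp8 path stages the checkpoint on CPU in bf16 before the fp8 conversion, while the non-fp8 path performs its one load_state_dict only after the CUDA default tensor type has been set. Below is a minimal, self-contained sketch of that control flow, assuming only that PyTorch is installed; Model, build_and_load, dim, and the toy state_dict are hypothetical stand-ins for illustration, not llama-stack code.

    # Minimal sketch of the two load paths in this commit; all names here
    # are stand-ins, not llama-stack APIs.
    import torch
    import torch.nn as nn


    class Model(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            # nn.Linear materializes its weight with the *current* default
            # tensor type, which is why the ordering below matters.
            self.proj = nn.Linear(dim, dim, bias=False)


    def build_and_load(state_dict: dict, fp8: bool) -> Model:
        if fp8:
            # fp8 path: stage weights on CPU in bf16 so the later fp8
            # conversion does not find an unexpected dtype (fp32, e.g.).
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Model()

        if fp8:
            # fp8 path loads immediately, while the default is still CPU.
            model.load_state_dict(state_dict, strict=False)

        # Switch the default to a CUDA tensor type (guarded here so the
        # sketch also runs on a CPU-only machine).
        if torch.cuda.is_available():
            if torch.cuda.is_bf16_supported():
                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
            else:
                torch.set_default_tensor_type(torch.cuda.HalfTensor)

        if not fp8:
            # non-fp8 path: its single load_state_dict runs only after the
            # CUDA default type is set, mirroring the ordering in the diff.
            model.load_state_dict(state_dict, strict=False)

        return model


    model = build_and_load({"proj.weight": torch.randn(16, 16)}, fp8=False)
    print(next(model.parameters()).dtype)

Hoisting the condition into one fp8 boolean keeps the two load_state_dict call sites symmetric and avoids re-evaluating the quantization config at each branch.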