Don't load as bf16 on CPU unless fp8 is active

Ashwin Bharambe 2024-07-22 19:09:32 -07:00
parent 8cd2e4164c
commit fef679bb34
2 changed files with 13 additions and 4 deletions


@@ -48,7 +48,10 @@ class Llama:
         if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
             raise NotImplementedError("HuggingFace checkpoints not supported yet")
-        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
             from .quantization.loader import is_fbgemm_available
             if not is_fbgemm_available():
@@ -99,8 +102,13 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
-        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
-        torch.set_default_tensor_type(torch.BFloat16Tensor)
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
+            # load on CPU in bf16 so that fp8 conversion does not find an
+            # unexpected (fp32, e.g.) datatype
+            torch.set_default_tensor_type(torch.BFloat16Tensor)
         model = Transformer(model_args)
         model.load_state_dict(state_dict, strict=False)
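
For context, here is a minimal sketch (not part of this diff, assuming only a working `torch` install) of why the bf16 default is now applied only on the fp8 path: `torch.set_default_tensor_type(torch.BFloat16Tensor)` changes the dtype of every tensor and module parameter created afterwards, so calling it unconditionally would force bf16 CPU weights even for checkpoints that never go through fp8 conversion.

```python
import torch

# Default behavior: newly created parameters are float32 on CPU.
before = torch.nn.Linear(4, 4).weight.dtype  # torch.float32

# What the (now fp8-only) guard enables: every tensor/module created
# afterwards defaults to bf16, which is what the fp8 conversion expects
# to find when it walks the loaded state dict.
torch.set_default_tensor_type(torch.BFloat16Tensor)
after = torch.nn.Linear(4, 4).weight.dtype  # torch.bfloat16

print(before, after)

# Restore the usual default so unrelated code is unaffected.
torch.set_default_tensor_type(torch.FloatTensor)
```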


@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -86,7 +87,7 @@ class LlamaModelParallelGenerator:
         logprobs: bool = False,
     ) -> Generator:
         req_obj = InferenceArgs(
-            messages=messages,
+            messages=deepcopy(messages),
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
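
A small, hypothetical illustration (the `Message`/`Request` classes below are stand-ins, not the repo's real types) of what `deepcopy(messages)` guards against: if the request object only held a reference to the caller's message list, any later mutation on either side would leak through to the other; copying decouples the `InferenceArgs` payload from the caller.

```python
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List

# Hypothetical stand-ins for the repo's message/request types.
@dataclass
class Message:
    role: str
    content: str

@dataclass
class Request:
    messages: List[Message] = field(default_factory=list)

caller_messages = [Message("user", "hello")]

shared = Request(messages=caller_messages)               # shares the caller's objects
isolated = Request(messages=deepcopy(caller_messages))   # independent copy

# Caller mutates its own list after submitting the request.
caller_messages[0].content = "edited after submission"

print(shared.messages[0].content)    # "edited after submission" -> mutation leaked in
print(isolated.messages[0].content)  # "hello" -> request payload unaffected
```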