Don't load as bf16 on CPU unless fp8 is active

Ashwin Bharambe 2024-07-22 19:09:32 -07:00
parent 8cd2e4164c
commit fef679bb34
2 changed files with 13 additions and 4 deletions


@@ -48,7 +48,10 @@ class Llama:
         if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
             raise NotImplementedError("HuggingFace checkpoints not supported yet")
 
-        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
             from .quantization.loader import is_fbgemm_available
 
             if not is_fbgemm_available():
@@ -99,7 +102,12 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
-        torch.set_default_tensor_type(torch.BFloat16Tensor)
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
+            # load on CPU in bf16 so that fp8 conversion does not find an
+            # unexpected (fp32, e.g.) datatype
+            torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args)
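
In short: torch.set_default_tensor_type is process-global state that changes the dtype of every tensor created afterwards, so the commit scopes it to the fp8 path instead of applying it to every CPU load. A minimal standalone sketch of the effect, assuming a hypothetical fp8_active flag in place of the config.quantization check:

import torch

def build_weights(fp8_active: bool) -> torch.Tensor:
    if fp8_active:
        # fp8 conversion expects bf16 inputs, so make every tensor
        # created from here on default to bf16
        torch.set_default_tensor_type(torch.BFloat16Tensor)
    # dtype depends on the flag: bf16 when fp8 is active,
    # otherwise the ordinary fp32 default
    return torch.empty(4, 4)

print(build_weights(fp8_active=False).dtype)  # torch.float32
print(build_weights(fp8_active=True).dtype)   # torch.bfloat16

Before this change the bf16 default was set unconditionally, so even non-fp8 CPU loads picked up bf16 weights; the new guard leaves them at the fp32 default.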


@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -86,7 +87,7 @@ class LlamaModelParallelGenerator:
         logprobs: bool = False,
     ) -> Generator:
         req_obj = InferenceArgs(
-            messages=messages,
+            messages=deepcopy(messages),
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
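
The second file's fix is about aliasing: req_obj previously kept a reference to the caller's messages list, so a mutation on either side was visible to the other. A reduced sketch of the difference, using a stand-in InferenceArgs rather than the repo's real class:

from copy import deepcopy
from dataclasses import dataclass
from typing import List

@dataclass
class Message:
    role: str
    content: str

@dataclass
class InferenceArgs:  # stand-in for the real request object
    messages: List[Message]

history = [Message("user", "hello")]
aliased = InferenceArgs(messages=history)             # shares the caller's list
isolated = InferenceArgs(messages=deepcopy(history))  # private copy

history[0].content = "edited after the request was built"
print(aliased.messages[0].content)   # "edited after the request was built"
print(isolated.messages[0].content)  # "hello"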