From fef679bb34d491425ff71c129d3c80eaa49b183c Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 22 Jul 2024 19:09:32 -0700
Subject: [PATCH] Don't load as bf16 on CPU unless fp8 is active

---
 llama_toolchain/inference/generation.py     | 14 +++++++++++---
 llama_toolchain/inference/model_parallel.py |  3 ++-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/llama_toolchain/inference/generation.py b/llama_toolchain/inference/generation.py
index 968c0e4d7..2411c69f8 100644
--- a/llama_toolchain/inference/generation.py
+++ b/llama_toolchain/inference/generation.py
@@ -48,7 +48,10 @@ class Llama:
         if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
             raise NotImplementedError("HuggingFace checkpoints not supported yet")

-        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
             from .quantization.loader import is_fbgemm_available

             if not is_fbgemm_available():
@@ -99,8 +102,13 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

-        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
-        torch.set_default_tensor_type(torch.BFloat16Tensor)
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
+            # load on CPU in bf16 so that fp8 conversion does not find an
+            # unexpected (fp32, e.g.) datatype
+            torch.set_default_tensor_type(torch.BFloat16Tensor)

         model = Transformer(model_args)
         model.load_state_dict(state_dict, strict=False)
diff --git a/llama_toolchain/inference/model_parallel.py b/llama_toolchain/inference/model_parallel.py
index 2d9737a9c..42b7091c1 100644
--- a/llama_toolchain/inference/model_parallel.py
+++ b/llama_toolchain/inference/model_parallel.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -86,7 +87,7 @@ class LlamaModelParallelGenerator:
         logprobs: bool = False,
     ) -> Generator:
         req_obj = InferenceArgs(
-            messages=messages,
+            messages=deepcopy(messages),
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
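
For illustration only, a minimal standalone sketch of the behavior the generation.py hunks change: the bf16 default tensor type is now set only when fp8 quantization is active, so a plain CPU load keeps PyTorch's normal fp32 default. TinyModel and build_model below are hypothetical stand-ins for the real Transformer and loader, not llama_toolchain code:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the real Transformer model.
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    def build_model(fp8_active: bool) -> nn.Module:
        if fp8_active:
            # Mirrors the patched code path: force bf16 on CPU only when
            # fp8 conversion will run, so it never encounters fp32 weights.
            torch.set_default_tensor_type(torch.BFloat16Tensor)
        try:
            return TinyModel()
        finally:
            # Restore the default so later allocations are unaffected.
            torch.set_default_tensor_type(torch.FloatTensor)

    print(build_model(fp8_active=True).proj.weight.dtype)   # torch.bfloat16
    print(build_model(fp8_active=False).proj.weight.dtype)  # torch.float32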
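The model_parallel.py hunk copies the messages before building InferenceArgs, presumably so the request object no longer aliases the caller's mutable message list. A quick sketch of the aliasing hazard that deepcopy avoids (plain dicts stand in for the real message and InferenceArgs types):

    from copy import deepcopy

    messages = [{"role": "user", "content": "hi"}]
    req_shallow = {"messages": messages}           # aliases the caller's list
    req_deep = {"messages": deepcopy(messages)}    # independent copy

    messages[0]["content"] = "edited after submit"
    print(req_shallow["messages"][0]["content"])   # "edited after submit"
    print(req_deep["messages"][0]["content"])      # "hi"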