Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-28 15:02:37 +00:00
Don't load as bf16 on CPU unless fp8 is active
parent 8cd2e4164c
commit fef679bb34

2 changed files with 13 additions and 4 deletions
@@ -48,7 +48,10 @@ class Llama:
         if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
             raise NotImplementedError("HuggingFace checkpoints not supported yet")

-        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
             from .quantization.loader import is_fbgemm_available

             if not is_fbgemm_available():
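A note on the guard imported in this hunk: is_fbgemm_available comes from .quantization.loader and reports whether the fbgemm kernels needed for fp8 can be used. A minimal sketch of what such a check could look like, assuming it only needs to probe for the fbgemm_gpu package (this is an assumption for illustration, not the actual loader code):

import importlib.util

# Hypothetical availability check (assumption: it only needs to know whether
# the fbgemm_gpu package, which provides the fp8 kernels, is importable).
def is_fbgemm_available() -> bool:
    return importlib.util.find_spec("fbgemm_gpu") is not None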
@@ -99,8 +102,13 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

-        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
-        torch.set_default_tensor_type(torch.BFloat16Tensor)
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
+            # load on CPU in bf16 so that fp8 conversion does not find an
+            # unexpected (fp32, e.g.) datatype
+            torch.set_default_tensor_type(torch.BFloat16Tensor)

         model = Transformer(model_args)
         model.load_state_dict(state_dict, strict=False)
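For context on this hunk, a minimal sketch (not part of the commit) of why torch.set_default_tensor_type(torch.BFloat16Tensor) is now gated on fp8: the default tensor type controls the dtype of newly constructed CPU parameters, so forcing bf16 only matters when the weights are about to be handed to the fp8 conversion; without fp8 it just changes the CPU dtype of the model for no benefit.

import torch
import torch.nn as nn

# By default, freshly constructed modules get fp32 parameters on CPU.
layer = nn.Linear(8, 8)
print(layer.weight.dtype)  # torch.float32

# With the bf16 default tensor type set (as the loader now does only when
# fp8 quantization is configured), new parameters come out as bf16.
torch.set_default_tensor_type(torch.BFloat16Tensor)
layer_bf16 = nn.Linear(8, 8)
print(layer_bf16.weight.dtype)  # torch.bfloat16

# Restore the usual default so later allocations are unaffected.
torch.set_default_tensor_type(torch.FloatTensor)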
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -86,7 +87,7 @@ class LlamaModelParallelGenerator:
         logprobs: bool = False,
     ) -> Generator:
         req_obj = InferenceArgs(
-            messages=messages,
+            messages=deepcopy(messages),
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
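The second file now copies messages before building the request. A small illustrative sketch (not from this repo) of the aliasing this avoids: without a copy, any mutation of the request's message list is also visible to the caller.

from copy import deepcopy

messages = [{"role": "user", "content": "hello"}]

# Holding a reference: appending to the request's list also mutates the caller's.
shared = messages
shared.append({"role": "assistant", "content": "hi"})
assert len(messages) == 2  # caller's list changed

# Holding a deep copy: the caller's list stays untouched.
messages = [{"role": "user", "content": "hello"}]
copied = deepcopy(messages)
copied.append({"role": "assistant", "content": "hi"})
assert len(messages) == 1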