Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00
Don't load as bf16 on CPU unless fp8 is active
parent 8cd2e4164c
commit fef679bb34

2 changed files with 13 additions and 4 deletions
@@ -48,7 +48,10 @@ class Llama:
         if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
             raise NotImplementedError("HuggingFace checkpoints not supported yet")
 
-        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
             from .quantization.loader import is_fbgemm_available
 
             if not is_fbgemm_available():
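The `is_fbgemm_available()` guard is the usual optional-dependency probe. A minimal sketch of what such a helper typically looks like, assuming the optional package is `fbgemm_gpu` (the toolchain's actual loader may probe differently):

def is_fbgemm_available() -> bool:
    # Probe for the optional fbgemm_gpu package without failing at
    # module import time; fp8 quantization is refused when it is absent.
    try:
        import fbgemm_gpu  # noqa: F401  # assumed package name
        return True
    except ImportError:
        return False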
@@ -99,8 +102,13 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
-        torch.set_default_tensor_type(torch.BFloat16Tensor)
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
+            # load on CPU in bf16 so that fp8 conversion does not find an
+            # unexpected (fp32, e.g.) datatype
+            torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args)
         model.load_state_dict(state_dict, strict=False)
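Taken together, the hunks above make the bf16 default-dtype switch conditional on fp8 quantization, as the commit title says. A self-contained sketch of the resulting decision, using hypothetical stand-ins for the toolchain's config types (only the quantization check mirrors the committed code):

from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch

# Hypothetical stand-ins, just to make the dtype decision runnable here.
class QuantizationType(Enum):
    fp8 = "fp8"

@dataclass
class QuantizationConfig:
    type: str

@dataclass
class ModelConfig:
    quantization: Optional[QuantizationConfig] = None

def set_load_dtype(config: ModelConfig) -> None:
    if (
        config.quantization
        and config.quantization.type == QuantizationType.fp8.value
    ):
        # Force bf16 only when fp8 conversion will run; otherwise a plain
        # CPU load keeps the checkpoint's native dtype (e.g. fp32).
        torch.set_default_tensor_type(torch.BFloat16Tensor)

set_load_dtype(ModelConfig())                           # no-op on CPU
set_load_dtype(ModelConfig(QuantizationConfig("fp8")))  # switches to bf16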
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -86,7 +87,7 @@ class LlamaModelParallelGenerator:
         logprobs: bool = False,
     ) -> Generator:
         req_obj = InferenceArgs(
-            messages=messages,
+            messages=deepcopy(messages),
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
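The `deepcopy` presumably keeps the parallel generator from mutating the caller's message list in place. A small illustration of the difference, using a plain dict in place of the toolchain's message type (hypothetical, for demonstration only):

from copy import deepcopy

# Without deepcopy, in-place edits through the request alias the
# caller's list:
messages = [{"role": "user", "content": "hello"}]
aliased = messages
aliased[0]["content"] = "mutated"
assert messages[0]["content"] == "mutated"

# With deepcopy, the request holds an independent copy:
messages = [{"role": "user", "content": "hello"}]
copied = deepcopy(messages)
copied[0]["content"] = "mutated"
assert messages[0]["content"] == "hello"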