Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
Fix fp8 implementation, which had bit-rotted a bit
I only tested with "on-the-fly" bf16 -> fp8 conversion, not the "load from fp8" codepath. YAML I tested with:

```yaml
providers:
  - provider_id: quantized
    provider_type: meta-reference-quantized
    config:
      model: Llama3.1-8B-Instruct
      quantization:
        type: fp8
```
parent 80ada04f76
commit 09b793c4d6
2 changed files with 10 additions and 7 deletions
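For readers new to the "on-the-fly" path: fp8 here means scaled 8-bit floating point. Below is a minimal, hypothetical sketch of row-wise bf16 -> fp8 weight quantization; the repository's real implementation lives in `fp8_impls.quantize_fp8` (imported in the diff further down) and is not reproduced here.

```python
# Illustrative only: row-wise scaled fp8 (e4m3) quantization of a bf16 weight.
# The actual kernels and layouts in fp8_impls.quantize_fp8 are not shown.
import torch

def quantize_rowwise_fp8(w_bf16: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    # One scale per output row, derived from that row's absolute maximum.
    amax = w_bf16.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    scale = fp8_max / amax
    w_fp8 = (w_bf16.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8, scale.reciprocal()  # inverse scale is applied at matmul time

w_fp8, inv_scale = quantize_rowwise_fp8(torch.randn(128, 256, dtype=torch.bfloat16))
print(w_fp8.dtype, inv_scale.shape)  # torch.float8_e4m3fn, (128, 1)
```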
In the generation code (class Llama):

```diff
@@ -138,7 +138,7 @@ class Llama:
             else:
                 model = Transformer(model_args)
                 model.load_state_dict(state_dict, strict=False)
-                model = convert_to_quantized_model(model, config)
+                model = convert_to_quantized_model(model, config, ckpt_dir)
         else:
             if torch.cuda.is_bf16_supported():
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
```
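This call-site fix pairs with the loader change further down, which adds `checkpoint_dir` as a required positional parameter. A hedged illustration of why the old two-argument call had rotted (stub signature only; the real function body is elided):

```python
# Stub with the new signature from this commit; body elided.
def convert_to_quantized_model(model, config, checkpoint_dir,
                               fp8_activation_scale_ub=1200.0):
    ...

try:
    convert_to_quantized_model("model", "config")  # the pre-fix call site
except TypeError as e:
    print(e)  # missing 1 required positional argument: 'checkpoint_dir'

convert_to_quantized_model("model", "config", "/path/to/ckpt")  # the fixed call
```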
```diff
@@ -228,8 +228,7 @@ class Llama:
                     ignore_index=pad_id,
                 )
 
-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
-
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(
```
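Why the `device="cuda"` matters: in the decode loop the freshly sampled tokens live on the GPU, and comparing them against a CPU-resident `stop_tokens` tensor raises a device-mismatch error. A small repro, using `torch.isin` as a stand-in for the loop's stop-token check and example Llama 3 stop ids (assumes a CUDA device):

```python
# Repro of the device mismatch the hunk above fixes; requires a CUDA device.
import torch

next_token = torch.tensor([128001], device="cuda")  # sampled on the GPU
stop_tokens_cpu = torch.tensor([128001, 128009])    # example stop ids, left on CPU

try:
    torch.isin(next_token, stop_tokens_cpu)
except RuntimeError as err:
    print(f"fails before the fix: {err}")

stop_tokens = torch.tensor([128001, 128009], device="cuda")  # the fixed version
print(torch.isin(next_token, stop_tokens))  # tensor([True], device='cuda:0')
```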
In the quantization loader:

```diff
@@ -13,9 +13,10 @@ from typing import Optional
 import torch
 
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 
+from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 
 from llama_models.sku_list import resolve_model
 from termcolor import cprint
 from torch import Tensor
```
```diff
@@ -39,6 +40,7 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer,
     config: MetaReferenceQuantizedInferenceConfig,
+    checkpoint_dir: str,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:
```
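On the retained `fp8_activation_scale_ub` parameter: an activation-scale upper bound typically caps the observed absmax before the fp8 scale is derived, so rare outliers don't stretch the quantization range. A hedged sketch of that idea only; the real use is inside `fp8_impls`, which this diff doesn't show:

```python
# Illustrative only: cap the activation absmax at scale_ub before computing
# the fp8 scale, so a single outlier doesn't dominate the dynamic range.
import torch

def fp8_activation_scale(x: torch.Tensor, scale_ub: float = 1200.0) -> torch.Tensor:
    amax = x.abs().amax().float().clamp(min=1e-12, max=scale_ub)
    return torch.finfo(torch.float8_e4m3fn).max / amax

print(fp8_activation_scale(torch.randn(16, 64) * 5000.0))  # absmax capped at 1200.0
```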
```diff
@@ -49,12 +51,14 @@ def convert_to_quantized_model(
 
     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
 
-    checkpoint = config.checkpoint_config.checkpoint
+    llama_model = resolve_model(config.model)
+    assert llama_model is not None, f"Model {config.model} not found"
+
     # Move weights to GPU with quantization
-    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
+    if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
         cprint("Loading fp8 scales...", "yellow")
         fp8_scales_path = os.path.join(
-            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
+            checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
         )
         assert os.path.isfile(
             fp8_scales_path
```
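Net effect of this last hunk: the loader no longer reaches through `config.checkpoint_config.checkpoint` (the attribute path that had rotted), but resolves the model SKU via `resolve_model` and uses the explicitly passed `checkpoint_dir` to find the per-rank scales file. A hedged sketch of the resulting lookup; only the filename template comes from the diff, while the path and the `torch.load` step are assumptions:

```python
# How the per-rank fp8 scales file is located after this change (sketch).
import os
import torch

checkpoint_dir = "/path/to/checkpoints/Llama3.1-8B-Instruct"  # hypothetical path
rank = 0  # stands in for get_model_parallel_rank()
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
assert os.path.isfile(fp8_scales_path), f"missing fp8 scales: {fp8_scales_path}"
fp8_scales = torch.load(fp8_scales_path, map_location="cuda")  # assumed loading step
```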