Fix fp8 implementation which had bit-rotted a bit

I only tested the "on-the-fly" bf16 -> fp8 conversion, not the "load
from fp8" codepath.

YAML I tested with:

```
providers:
  - provider_id: quantized
    provider_type: meta-reference-quantized
    config:
      model: Llama3.1-8B-Instruct
      quantization:
        type: fp8
```
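
For reference, a minimal sketch of what the "on-the-fly" bf16 -> fp8 path boils down to, assuming torch's `float8_e4m3fn` dtype and a per-row scale. This is an illustration only, not the `fp8_impls.quantize_fp8` implementation in this repo.

```python
# Illustration only: row-wise bf16 -> fp8 quantization, not the repo's fp8_impls.
import torch

def quantize_rowwise_fp8(w_bf16: torch.Tensor):
    # Scale each row so its largest magnitude maps to the fp8 e4m3 max (~448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = w_bf16.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / fp8_max
    w_fp8 = (w_bf16.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8, scale  # dequantize later as w_fp8.float() * scale

w = torch.randn(4, 8, dtype=torch.bfloat16)
w_fp8, scale = quantize_rowwise_fp8(w)
print(w_fp8.dtype, scale.shape)  # torch.float8_e4m3fn torch.Size([4, 1])
```

The per-row scale keeps the largest weight in each row representable at the fp8 maximum, the usual trade-off for e4m3's narrow dynamic range.
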
Ashwin Bharambe 2024-10-15 13:57:01 -07:00
parent 80ada04f76
commit 09b793c4d6
2 changed files with 10 additions and 7 deletions


```diff
@@ -138,7 +138,7 @@ class Llama:
             else:
                 model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
-            model = convert_to_quantized_model(model, config)
+            model = convert_to_quantized_model(model, config, ckpt_dir)
         else:
             if torch.cuda.is_bf16_supported():
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@@ -228,8 +228,7 @@ class Llama:
                 ignore_index=pad_id,
             )
-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(
```
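
The second hunk pins `stop_tokens` to the GPU, presumably because the sampled `next_token` already lives on CUDA and the stop-token check (something like `torch.isin`) needs both tensors on the same device. A small, hypothetical illustration (requires a CUDA device; token values are made up):

```python
# Hypothetical illustration (not this repo's generation loop): mixing a GPU
# token with CPU stop tokens raises a device-mismatch error.
import torch

assert torch.cuda.is_available(), "illustration requires a CUDA device"

next_token = torch.tensor([42], device="cuda")  # token sampled on the GPU
stop_tokens_cpu = torch.tensor([1, 2, 42])      # CPU tensor

try:
    torch.isin(next_token, stop_tokens_cpu)
except RuntimeError as err:
    print("device mismatch:", err)

stop_tokens = torch.tensor([1, 2, 42], device="cuda")
print(torch.isin(next_token, stop_tokens))  # tensor([True], device='cuda:0')
```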


```diff
@@ -13,9 +13,10 @@ from typing import Optional
 import torch
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
+from llama_models.sku_list import resolve_model
 from termcolor import cprint
 from torch import Tensor
@@ -39,6 +40,7 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer,
     config: MetaReferenceQuantizedInferenceConfig,
+    checkpoint_dir: str,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:
@@ -49,12 +51,14 @@ def convert_to_quantized_model(
     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
-    checkpoint = config.checkpoint_config.checkpoint
+    llama_model = resolve_model(config.model)
+    assert llama_model is not None, f"Model {config.model} not found"
     # Move weights to GPU with quantization
-    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
+    if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
         cprint("Loading fp8 scales...", "yellow")
         fp8_scales_path = os.path.join(
-            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
+            checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
         )
         assert os.path.isfile(
             fp8_scales_path
```
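
To summarize the loader change: the checkpoint directory is now passed in explicitly, and the model's quantization format comes from `resolve_model(config.model)` rather than a checkpoint config object. A hedged sketch of the resulting decision, with illustrative string values standing in for the `CheckpointQuantizationFormat` enum:

```python
# Hedged sketch, not the repo's code: fp8_mixed checkpoints read precomputed
# per-rank scale files from checkpoint_dir; anything else is quantized on the
# fly from the freshly loaded bf16 weights.
import os

def select_fp8_source(quantization_format: str, checkpoint_dir: str, rank: int) -> str:
    if quantization_format == "fp8_mixed":  # illustrative value, not the enum
        path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
        assert os.path.isfile(path), f"Missing fp8 scales file {path}"
        return f"load_fp8 using scales from {path}"
    return "quantize_fp8 on the fly from bf16 weights"

# The codepath exercised in this commit: no precomputed scales, quantize on the fly.
print(select_fp8_source("bf16", "/unused", 0))
```

The "load from fp8" branch (precomputed per-rank scales) was not re-tested here, as noted above.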