Fix fp8 implementation which had bit-rotten a bit

I only tested with "on-the-fly" bf16 -> fp8 conversion, not the "load
from fp8" codepath.

YAML I tested with:

```
providers:
  - provider_id: quantized
    provider_type: meta-reference-quantized
    config:
      model: Llama3.1-8B-Instruct
      quantization:
        type: fp8
```
This commit is contained in:
Ashwin Bharambe 2024-10-15 13:57:01 -07:00
parent 80ada04f76
commit 09b793c4d6
2 changed files with 10 additions and 7 deletions

View file

@ -138,7 +138,7 @@ class Llama:
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config)
model = convert_to_quantized_model(model, config, ckpt_dir)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@ -228,8 +228,7 @@ class Llama:
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(

View file

@ -13,9 +13,10 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import Tensor
@ -39,6 +40,7 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
@ -49,12 +51,14 @@ def convert_to_quantized_model(
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
checkpoint = config.checkpoint_config.checkpoint
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
fp8_scales_path = os.path.join(
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path