From 09b793c4d6e981abf2140b35ec9560ce90fa73be Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 15 Oct 2024 13:57:01 -0700
Subject: [PATCH] Fix fp8 implementation which had bit-rotted a bit

I only tested with "on-the-fly" bf16 -> fp8 conversion, not the
"load from fp8" codepath.

YAML I tested with:

```
providers:
  - provider_id: quantized
    provider_type: meta-reference-quantized
    config:
      model: Llama3.1-8B-Instruct
      quantization:
        type: fp8
```
---
 .../impls/meta_reference/inference/generation.py    |  5 ++---
 .../meta_reference/inference/quantization/loader.py | 12 ++++++++----
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 8d94a20d1..9037b9acd 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -138,7 +138,7 @@ class Llama:
             else:
                 model = Transformer(model_args)
                 model.load_state_dict(state_dict, strict=False)
-            model = convert_to_quantized_model(model, config)
+            model = convert_to_quantized_model(model, config, ckpt_dir)
         else:
             if torch.cuda.is_bf16_supported():
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@@ -228,8 +228,7 @@ class Llama:
                 ignore_index=pad_id,
             )

-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
-
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(
diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
index 92b3a6ce3..bd59fe618 100644
--- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
+++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
@@ -13,9 +13,10 @@ from typing import Optional
 import torch

 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-
 from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
+
+from llama_models.sku_list import resolve_model
 from termcolor import cprint
 from torch import Tensor

@@ -39,6 +40,7 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer,
     config: MetaReferenceQuantizedInferenceConfig,
+    checkpoint_dir: str,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:
@@ -49,12 +51,14 @@

     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8

-    checkpoint = config.checkpoint_config.checkpoint
+    llama_model = resolve_model(config.model)
+    assert llama_model is not None, f"Model {config.model} not found"
+
     # Move weights to GPU with quantization
-    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
+    if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
         cprint("Loading fp8 scales...", "yellow")
         fp8_scales_path = os.path.join(
-            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
+            checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
         )
         assert os.path.isfile(
             fp8_scales_path
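
Reviewer's note (not part of the patch): since only the on-the-fly bf16 -> fp8 conversion was tested, below is a minimal pre-flight sketch of what the untested "load from fp8" codepath now relies on, mirroring the loader changes above: resolve the model via resolve_model(config.model), check its quantization_format, and look for per-rank fp8_scales_<rank>.pt files under the checkpoint directory. The helper name check_fp8_checkpoint, the example descriptor, and the checkpoint path are assumptions for illustration only.

```
# Hypothetical pre-flight check; mirrors the resolve_model() + fp8_scales_<rank>.pt
# lookup introduced in quantization/loader.py. Names and paths here are examples.
import os

from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.sku_list import resolve_model


def check_fp8_checkpoint(model_descriptor: str, checkpoint_dir: str, world_size: int) -> None:
    llama_model = resolve_model(model_descriptor)
    assert llama_model is not None, f"Model {model_descriptor} not found"

    if llama_model.quantization_format != CheckpointQuantizationFormat.fp8_mixed.value:
        # For non-fp8_mixed checkpoints the loader quantizes bf16 weights on the fly.
        print(f"{model_descriptor}: not fp8_mixed, on-the-fly quantization will be used")
        return

    # The "load from fp8" codepath expects one scales file per model-parallel rank.
    for rank in range(world_size):
        scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
        assert os.path.isfile(scales_path), f"missing {scales_path}"
    print(f"{model_descriptor}: fp8 scales present for all {world_size} rank(s)")


if __name__ == "__main__":
    # Example invocation matching the tested YAML; adjust the directory for your setup.
    check_fp8_checkpoint(
        "Llama3.1-8B-Instruct",
        os.path.expanduser("~/.llama/checkpoints/Llama3.1-8B-Instruct"),
        world_size=1,
    )
```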