make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config
- also kill GeneratorArgs
Ashwin Bharambe 2024-07-20 21:10:17 -07:00
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions
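Since GeneratorArgs is removed, here is a minimal sketch of how the reworked config is meant to be assembled. It is illustrative only: the import paths follow the files changed below, and the checkpoint/tokenizer paths are placeholders.

# Illustrative sketch, not part of the commit; paths are placeholders and the
# import locations follow the files changed below.
from toolchain.inference.api.config import (
    CheckpointQuantizationFormat,
    InlineImplConfig,
    ModelCheckpointConfig,
    PytorchCheckpoint,
)
from toolchain.inference.api.datatypes import Fp8QuantizationConfig

config = InlineImplConfig(
    checkpoint_config=ModelCheckpointConfig(
        checkpoint=PytorchCheckpoint(
            checkpoint_dir="/path/to/checkpoint",
            tokenizer_path="/path/to/tokenizer.model",
            model_parallel_size=1,
            quantization_format=CheckpointQuantizationFormat.fp8_mixed,
        )
    ),
    quantization=Fp8QuantizationConfig(),  # replaces the old GeneratorArgs plumbing
    max_seq_len=2048,
    max_batch_size=1,
)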

View file

@@ -7,3 +7,5 @@ model_inference_config:
model_parallel_size: 1
max_seq_len: 2048
max_batch_size: 1
quantization:
type: "fp8"

View file

@@ -7,14 +7,7 @@ from hydra.core.config_store import ConfigStore
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@dataclass
class GeneratorArgs:
ckpt_dir: str
tokenizer_path: str
model_parallel_size: Optional[int] = None
max_seq_len: int = 2048
max_batch_size: int = 4
from .datatypes import QuantizationConfig
class ImplType(Enum):
@@ -27,6 +20,17 @@ class CheckpointType(Enum):
huggingface = "huggingface"
# This enum represents the format in which weights are specified
# This does not necessarily match the quantization desired at runtime,
# since on-the-fly conversions can be performed
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference; some weights remain bf16
fp8_mixed = "fp8_mixed"
class PytorchCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
CheckpointType.pytorch.value
@@ -34,6 +38,9 @@ class PytorchCheckpoint(BaseModel):
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
class HuggingFaceCheckpoint(BaseModel):
@@ -42,6 +49,9 @@ class HuggingFaceCheckpoint(BaseModel):
)
repo_id: str # or model_name ?
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
class ModelCheckpointConfig(BaseModel):
@@ -51,10 +61,11 @@ class ModelCheckpointConfig(BaseModel):
]
# NOTE: this same config will naturally be used when instantiating an inference server
class InlineImplConfig(BaseModel):
impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
checkpoint_config: ModelCheckpointConfig
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int
max_batch_size: int = 1
@@ -86,6 +97,7 @@ class InlineImplHydraConfig:
model_parallel_size: int
max_seq_len: int
max_batch_size: int = 1
quantization: Optional[QuantizationConfig] = None
# TODO: huggingface checkpoint required args
def convert_to_inline_impl_config(self):
@@ -99,6 +111,7 @@ class InlineImplHydraConfig:
model_parallel_size=self.model_parallel_size,
)
),
quantization=self.quantization,
max_seq_len=self.max_seq_len,
max_batch_size=self.max_batch_size,
)

View file

@@ -21,19 +21,19 @@ class QuantizationType(Enum):
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
quantization_type: Literal[QuantizationType.bf16.value] = (
type: Literal[QuantizationType.bf16.value] = (
QuantizationType.bf16.value
)
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="quantization_type"),
Field(discriminator="type"),
]
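Renaming the discriminator from quantization_type to type is what lets the bare type: "fp8" key in the YAML config at the top of this commit resolve to the right model. A minimal sketch of that resolution, assuming pydantic v2's TypeAdapter (under pydantic v1, parse_obj_as plays the same role):

# Illustrative sketch, not part of the commit.
from pydantic import TypeAdapter
from toolchain.inference.api.datatypes import Fp8QuantizationConfig, QuantizationConfig

quant = TypeAdapter(QuantizationConfig).validate_python({"type": "fp8"})
assert isinstance(quant, Fp8QuantizationConfig)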

View file

@@ -7,7 +7,7 @@ import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional, TypedDict
from typing import Generator, List, Optional
import torch
import torch.nn.functional as F
@@ -23,6 +23,9 @@ from models.llama3_1.api.model import Transformer
from models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint
from .api.config import CheckpointType, InlineImplConfig
from .api.datatypes import QuantizationType
@dataclass
class TokenResult:
@@ -31,69 +34,52 @@ class TokenResult:
logprobs: Optional[List[float]] = None
class CompletionPrediction(TypedDict, total=False):
generation: str
tokens: List[str] # not required
logprobs: List[float] # not required
class Llama:
@staticmethod
def build(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
model_parallel_size: Optional[int] = None,
seed: int = 1,
) -> "Llama":
def build(config: InlineImplConfig):
"""
Build a Llama instance by initializing and loading a model checkpoint.
Args:
ckpt_dir (str): Path to the directory containing checkpoint files.
tokenizer_path (str): Path to the tokenizer file.
max_seq_len (int): Maximum sequence length for input text.
max_batch_size (int): Maximum batch size for inference.
model_parallel_size (Optional[int], optional): Number of model parallel processes.
If not provided, it's determined from the environment. Defaults to None.
Returns:
Llama: An instance of the Llama class with the loaded model and tokenizer.
Raises:
AssertionError: If there are no checkpoint files in the specified directory,
or if the model parallel size does not match the number of checkpoint files.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
checkpoint = config.checkpoint_config.checkpoint
if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
raise NotImplementedError("HuggingFace checkpoints not supported yet")
if config.quantization and config.quantization.type == QuantizationType.fp8.value:
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = checkpoint.model_parallel_size
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = checkpoint.checkpoint_dir
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@@ -103,22 +89,34 @@ class Llama:
params = params["model"]
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
# load on CPU in bf16 so that fp8 conversion does not find an unexpected datatype (e.g., fp32)
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
if config.quantization:
from .quantization.loader import convert_to_quantized_model
model = convert_to_quantized_model(model, config)
else:
model = model.to("cuda")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)

View file

@@ -2,10 +2,15 @@ from typing import AsyncGenerator
from models.llama3_1.api.datatypes import StopReason
from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
from .api.config import (
CheckpointQuantizationFormat,
CheckpointType,
InlineImplConfig,
)
from .api.datatypes import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
QuantizationConfig,
ToolCallDelta,
ToolCallParseStatus,
)
@@ -18,33 +23,13 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator
def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
if (
config.checkpoint_config.checkpoint.checkpoint_type
== CheckpointType.pytorch.value
):
pt_checkpoint = config.checkpoint_config.checkpoint
return GeneratorArgs(
ckpt_dir=pt_checkpoint.checkpoint_dir,
tokenizer_path=pt_checkpoint.tokenizer_path,
model_parallel_size=pt_checkpoint.model_parallel_size,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
)
else:
raise NotImplementedError("HF Checkpoint not supported yet")
class ModelInferenceImpl(ModelInference):
def __init__(self, config: InlineImplConfig) -> None:
self.config = config
async def initialize(self) -> None:
generator_args = generator_args_from_config(self.config)
self.generator = LlamaModelParallelGenerator(
args=generator_args,
)
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def shutdown(self) -> None:

View file

@@ -6,7 +6,7 @@ from models.llama3_1.api.chat_format import ChatFormat
from models.llama3_1.api.datatypes import Message
from models.llama3_1.api.tokenizer import Tokenizer
from .api.config import GeneratorArgs
from .api.config import InlineImplConfig
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup
@@ -35,13 +35,8 @@ class ModelRunner:
)
def init_model_cb(args: GeneratorArgs):
llama = Llama.build(
args.ckpt_dir,
args.tokenizer_path,
args.max_seq_len,
args.max_batch_size,
)
def init_model_cb(config: InlineImplConfig):
llama = Llama.build(config)
return ModelRunner(llama)
@@ -56,12 +51,13 @@ class LlamaModelParallelGenerator:
clear at the callsite why we need to use a context manager.
"""
def __init__(self, args: GeneratorArgs):
self.args = args
def __init__(self, config: InlineImplConfig):
self.config = config
# this is a hack: the Agent loop uses this to tokenize and check whether the input is too long
# while the tool-use loop is still running
self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path))
checkpoint = self.config.checkpoint_config.checkpoint
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
def start(self):
self.__enter__()
@@ -70,9 +66,10 @@ class LlamaModelParallelGenerator:
self.__exit__(None, None, None)
def __enter__(self):
checkpoint = self.config.checkpoint_config.checkpoint
self.group = ModelParallelProcessGroup(
self.args.model_parallel_size,
init_model_cb=partial(init_model_cb, self.args),
checkpoint.model_parallel_size,
init_model_cb=partial(init_model_cb, self.config),
)
self.group.start()
return self

View file

@@ -2,7 +2,6 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
from enum import Enum, unique
from typing import Optional, Type
try:
@@ -11,20 +10,12 @@ try:
print("Using efficient FP8 operators in FBGEMM.")
except (ImportError, ModuleNotFoundError):
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
raise
import torch
from torch import nn, Tensor
@unique
class FfnQuantizeMode(Enum):
FP8_ROWWISE = "fp8_rowwise"
NONE = "none"
def __str__(self) -> str:
return self.value
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@@ -84,7 +75,6 @@ def ffn_swiglu(
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
mode: Optional[FfnQuantizeMode] = None,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
@@ -92,14 +82,12 @@ def quantize_fp8(
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
mode (FfnQuantizeMode): Quantization mode.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE: # rowwise
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
@@ -110,6 +98,31 @@ def quantize_fp8(
)
@torch.inference_mode()
def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
Args:
w (Tensor): [n, k] input FP8 weight.
w_scale (Tensor): per-row scales for the FP8 weight.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
return Fp8RowwiseWeights(
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
scale=w_scale.to(device="cuda"),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
def fc_fp8_dynamic(
x: Tensor,
w: Fp8RowwiseWeights,

View file

@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from typing import Optional
import torch
from models.llama3_1.api.model import Transformer, TransformerBlock
from toolchain.inference.api.config import (
CheckpointQuantizationFormat,
InlineImplConfig,
)
from toolchain.inference.api.datatypes import (
QuantizationType,
)
from termcolor import cprint
def is_fbgemm_available() -> bool:
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
return True
except (ImportError, ModuleNotFoundError):
return False
def convert_to_quantized_model(
model: Transformer,
config: InlineImplConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
checkpoint = config.checkpoint_config.checkpoint
# Move weights to GPU with quantization
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
fp8_scales_path = os.path.join(
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.w1.weight = load_fp8(
block.feed_forward.w1.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w3.weight = load_fp8(
block.feed_forward.w3.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w2.weight = load_fp8(
block.feed_forward.w2.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
cprint("Quantizing fp8 weights from bf16...", "yellow")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.w1.weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
block.feed_forward.w3.weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
block.feed_forward.w2.weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model
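For reference, a sketch of what the loader above expects an fp8_scales_<rank>.pt file to contain: a flat dict keyed by "<layer_id>_feed_forward.<w1|w2|w3>_<rank>", mapping to the per-row scale tensor of each quantized FFN weight. The dimensions below are hypothetical placeholders, not values from the commit.

# Illustrative sketch, not part of the commit; dims and values are hypothetical.
import torch

dim, hidden_dim = 4096, 14336
fp8_scales = {
    # one entry per quantized FFN weight, per layer, per model-parallel rank
    "1_feed_forward.w1_0": torch.ones(hidden_dim, dtype=torch.float32),
    "1_feed_forward.w3_0": torch.ones(hidden_dim, dtype=torch.float32),
    "1_feed_forward.w2_0": torch.ones(dim, dtype=torch.float32),
}
torch.save(fp8_scales, "/path/to/checkpoint/fp8_scales_0.pt")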

View file

@@ -18,6 +18,12 @@ from fp8.fp8_impls import ffn_swiglu
from torch import nn
@dataclass
class QuantizationArgs:
fp8_rowwise: bool = False
convert_from_bf16: bool = False
@dataclass
class ModelArgs:
dim: int = 4096
@@ -31,6 +37,8 @@ class ModelArgs:
rope_theta: float = 500000
use_scaled_rope: bool = False
quantization: Optional[QuantizationArgs] = None
max_batch_size: int = 32
max_seq_len: int = 2048
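A minimal sketch of how the new QuantizationArgs could be threaded into ModelArgs; the import path is an assumption based on other imports in this commit.

# Illustrative sketch, not part of the commit; the import path is assumed.
from models.llama3_1.api.model import ModelArgs, QuantizationArgs

args = ModelArgs(
    quantization=QuantizationArgs(fp8_rowwise=True, convert_from_bf16=True),
    max_seq_len=2048,
    max_batch_size=1,
)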

View file

@@ -5,7 +5,7 @@ import unittest
import torch
from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
from hypothesis import given, settings, strategies as st
from torch import Tensor
@@ -33,70 +33,42 @@ class FP8Tests(unittest.TestCase):
UB: float,
) -> None:
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
w13 = (
torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w1 = (
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
)
w3 = (
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
)
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
x_q = quantize_fp8(x, UB)
w13_q = quantize_fp8(w13, UB)
w2_q = quantize_fp8(w2, UB)
x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
(B, T, D) = x.shape
(HD_L_2, D_) = w13.shape
(HD_L, D_) = w1.shape
assert D_ == D
HD_L = HD_L_2 // 2
y = x.view(B * T, D) @ w13.T
x1 = y[:, :HD_L]
x2 = y[:, HD_L:]
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
# Fake quant
x = x_q.weight.bfloat16() * x_q.scale
w13 = w13_q.weight.bfloat16() * w13_q.scale
w2 = w2_q.weight.bfloat16() * w2_q.scale
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
v_ref = ref_ffn(x, w13, w2)
v_ref = ref_ffn(x, w1, w3, w2)
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
@settings(deadline=None)
@given(
B_T=st.sampled_from([2048, 4096]),
D=st.sampled_from([128, 256]),
HD_L=st.sampled_from([256, 512]),
UB=st.sampled_from([1000, 10000]),
)
def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
B_T = 4096
D = 256
HD_L = 512
UB = float(UB)
x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
x_q = quantize_fp8(x, UB)
wqkv_q = quantize_fp8(wqkv, UB)
num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
y = attn_linear(x, wqkv_q)
y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
# Fake quant
x = x_q.weight.bfloat16() * x_q.scale
wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
y_ref = (x @ wqkv.T).to(torch.bfloat16)
torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)
if __name__ == "__main__":
unittest.main()