meta-llama/llama-stack (mirror of https://github.com/meta-llama/llama-stack.git)

commit ad62e2e1f3 (parent 7d2c0b14b8)

make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config
- also kill GeneratorArgs

10 changed files with 249 additions and 155 deletions
@@ -7,3 +7,5 @@ model_inference_config:
   model_parallel_size: 1
   max_seq_len: 2048
   max_batch_size: 1
+  quantization:
+    type: "fp8"

@@ -7,14 +7,7 @@ from hydra.core.config_store import ConfigStore
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
-
-@dataclass
-class GeneratorArgs:
-    ckpt_dir: str
-    tokenizer_path: str
-    model_parallel_size: Optional[int] = None
-    max_seq_len: int = 2048
-    max_batch_size: int = 4
+from .datatypes import QuantizationConfig
 
 
 class ImplType(Enum):

@@ -27,6 +20,17 @@ class CheckpointType(Enum):
     huggingface = "huggingface"
 
 
+# This enum represents the format in which weights are specified
+# This does not necessarily always equal what quantization is desired
+# at runtime since there can be on-the-fly conversions done
+class CheckpointQuantizationFormat(Enum):
+    # default format
+    bf16 = "bf16"
+
+    # used for enabling fp8_rowwise inference, some weights are bf16
+    fp8_mixed = "fp8_mixed"
+
+
 class PytorchCheckpoint(BaseModel):
     checkpoint_type: Literal[CheckpointType.pytorch.value] = (
         CheckpointType.pytorch.value

@@ -34,6 +38,9 @@ class PytorchCheckpoint(BaseModel):
     checkpoint_dir: str
     tokenizer_path: str
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class HuggingFaceCheckpoint(BaseModel):

@@ -42,6 +49,9 @@ class HuggingFaceCheckpoint(BaseModel):
     )
     repo_id: str  # or model_name ?
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class ModelCheckpointConfig(BaseModel):

@@ -51,10 +61,11 @@ class ModelCheckpointConfig(BaseModel):
     ]
 
 
-# NOTE: this same config will be used when instantiating an inference server naturally
 class InlineImplConfig(BaseModel):
     impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
     checkpoint_config: ModelCheckpointConfig
+    quantization: Optional[QuantizationConfig] = None
+    torch_seed: Optional[int] = None
     max_seq_len: int
     max_batch_size: int = 1
 

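Illustration (not part of this commit) of how the new fields fit together, using the class and field names from this diff and made-up paths: the checkpoint on disk stays in the default bf16 format, and requesting fp8 at runtime triggers the on-the-fly conversion mentioned in the CheckpointQuantizationFormat comments above.

    from toolchain.inference.api.config import (
        CheckpointQuantizationFormat,
        InlineImplConfig,
        ModelCheckpointConfig,
        PytorchCheckpoint,
    )
    from toolchain.inference.api.datatypes import Fp8QuantizationConfig

    config = InlineImplConfig(
        checkpoint_config=ModelCheckpointConfig(
            checkpoint=PytorchCheckpoint(
                checkpoint_dir="/path/to/checkpoint",       # made-up path
                tokenizer_path="/path/to/tokenizer.model",  # made-up path
                model_parallel_size=1,
                quantization_format=CheckpointQuantizationFormat.bf16,
            )
        ),
        quantization=Fp8QuantizationConfig(),  # request fp8 at runtime
        max_seq_len=2048,
        max_batch_size=1,
    )
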
@@ -86,6 +97,7 @@ class InlineImplHydraConfig:
     model_parallel_size: int
     max_seq_len: int
     max_batch_size: int = 1
+    quantization: Optional[QuantizationConfig] = None
     # TODO: huggingface checkpoint required args
 
     def convert_to_inline_impl_config(self):

@@ -99,6 +111,7 @@ class InlineImplHydraConfig:
                     model_parallel_size=self.model_parallel_size,
                 )
             ),
+            quantization=self.quantization,
             max_seq_len=self.max_seq_len,
             max_batch_size=self.max_batch_size,
         )

@@ -21,19 +21,19 @@ class QuantizationType(Enum):
 
 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
+    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
 
 
 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.bf16.value] = (
+    type: Literal[QuantizationType.bf16.value] = (
         QuantizationType.bf16.value
     )
 
 
 QuantizationConfig = Annotated[
     Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
-    Field(discriminator="quantization_type"),
+    Field(discriminator="type"),
 ]

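Renaming the discriminator to "type" is what lets the YAML hunk above (type: "fp8") map directly onto the right variant. A quick sketch (not from the repo; the wrapper model is illustrative and the import path follows the imports used elsewhere in this diff) of how pydantic resolves it:

    from pydantic import BaseModel

    from toolchain.inference.api.datatypes import Fp8QuantizationConfig, QuantizationConfig

    class _Holder(BaseModel):  # illustrative wrapper, not part of the codebase
        quantization: QuantizationConfig

    holder = _Holder(quantization={"type": "fp8"})
    assert isinstance(holder.quantization, Fp8QuantizationConfig)
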
@@ -7,7 +7,7 @@ import sys
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional, TypedDict
+from typing import Generator, List, Optional
 
 import torch
 import torch.nn.functional as F

@@ -23,6 +23,9 @@ from models.llama3_1.api.model import Transformer
 from models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint
 
+from .api.config import CheckpointType, InlineImplConfig
+from .api.datatypes import QuantizationType
+
 
 @dataclass
 class TokenResult:

@@ -31,69 +34,52 @@ class TokenResult:
     logprobs: Optional[List[float]] = None
 
 
-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 class Llama:
     @staticmethod
-    def build(
-        ckpt_dir: str,
-        tokenizer_path: str,
-        max_seq_len: int,
-        max_batch_size: int,
-        model_parallel_size: Optional[int] = None,
-        seed: int = 1,
-    ) -> "Llama":
+    def build(config: InlineImplConfig):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
 
-        Args:
-            ckpt_dir (str): Path to the directory containing checkpoint files.
-            tokenizer_path (str): Path to the tokenizer file.
-            max_seq_len (int): Maximum sequence length for input text.
-            max_batch_size (int): Maximum batch size for inference.
-            model_parallel_size (Optional[int], optional): Number of model parallel processes.
-                If not provided, it's determined from the environment. Defaults to None.
-
-        Returns:
-            Llama: An instance of the Llama class with the loaded model and tokenizer.
-
-        Raises:
-            AssertionError: If there are no checkpoint files in the specified directory,
-                or if the model parallel size does not match the number of checkpoint files.
-
         Note:
             This method initializes the distributed process group, sets the device to CUDA,
             and loads the pre-trained model and tokenizer.
         """
+        checkpoint = config.checkpoint_config.checkpoint
+        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
+            raise NotImplementedError("HuggingFace checkpoints not supported yet")
+
+        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+            from .quantization.loader import is_fbgemm_available
+
+            if not is_fbgemm_available():
+                raise ImportError("fbgemm-gpu is required for FP8 quantization")
+
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
 
+        model_parallel_size = checkpoint.model_parallel_size
         if not model_parallel_is_initialized():
-            if model_parallel_size is None:
-                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)
 
         # seed must be the same in all processes
-        torch.manual_seed(seed)
+        if config.torch_seed is not None:
+            torch.manual_seed(config.torch_seed)
 
         if local_rank > 0:
             sys.stdout = open(os.devnull, "w")
 
         start_time = time.time()
+        ckpt_dir = checkpoint.checkpoint_dir
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         assert model_parallel_size == len(
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())

@@ -103,22 +89,34 @@ class Llama:
             params = params["model"]
 
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=config.max_seq_len,
+            max_batch_size=config.max_batch_size,
             **params,
         )
-        tokenizer = Tokenizer(model_path=tokenizer_path)
+        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
 
         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
+        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
+        torch.set_default_tensor_type(torch.BFloat16Tensor)
+
+        model = Transformer(model_args)
+        model.load_state_dict(state_dict, strict=False)
+
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
-        model = Transformer(model_args)
-        model.load_state_dict(checkpoint, strict=False)
+        if config.quantization:
+            from .quantization.loader import convert_to_quantized_model
+
+            model = convert_to_quantized_model(model, config)
+        else:
+            model = model.to("cuda")
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
 
         return Llama(model, tokenizer, model_args)

@@ -2,10 +2,15 @@ from typing import AsyncGenerator
 
 from models.llama3_1.api.datatypes import StopReason
 
-from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
+from .api.config import (
+    CheckpointQuantizationFormat,
+    CheckpointType,
+    InlineImplConfig,
+)
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
+    QuantizationConfig,
     ToolCallDelta,
     ToolCallParseStatus,
 )

@@ -18,33 +23,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator
 
 
-def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
-    if (
-        config.checkpoint_config.checkpoint.checkpoint_type
-        == CheckpointType.pytorch.value
-    ):
-        pt_checkpoint = config.checkpoint_config.checkpoint
-        return GeneratorArgs(
-            ckpt_dir=pt_checkpoint.checkpoint_dir,
-            tokenizer_path=pt_checkpoint.tokenizer_path,
-            model_parallel_size=pt_checkpoint.model_parallel_size,
-            max_seq_len=config.max_seq_len,
-            max_batch_size=config.max_batch_size,
-        )
-    else:
-        raise NotImplementedError("HF Checkpoint not supported yet")
-
-
 class ModelInferenceImpl(ModelInference):
 
     def __init__(self, config: InlineImplConfig) -> None:
         self.config = config
 
     async def initialize(self) -> None:
-        generator_args = generator_args_from_config(self.config)
-        self.generator = LlamaModelParallelGenerator(
-            args=generator_args,
-        )
+        self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
 
     async def shutdown(self) -> None:

@@ -6,7 +6,7 @@ from models.llama3_1.api.chat_format import ChatFormat
 from models.llama3_1.api.datatypes import Message
 from models.llama3_1.api.tokenizer import Tokenizer
 
-from .api.config import GeneratorArgs
+from .api.config import InlineImplConfig
 from .generation import Llama
 from .parallel_utils import ModelParallelProcessGroup

@@ -35,13 +35,8 @@ class ModelRunner:
         )
 
 
-def init_model_cb(args: GeneratorArgs):
-    llama = Llama.build(
-        args.ckpt_dir,
-        args.tokenizer_path,
-        args.max_seq_len,
-        args.max_batch_size,
-    )
+def init_model_cb(config: InlineImplConfig):
+    llama = Llama.build(config)
     return ModelRunner(llama)

|
@ -56,12 +51,13 @@ class LlamaModelParallelGenerator:
|
||||||
clear at the callsite why we need to use a context manager.
|
clear at the callsite why we need to use a context manager.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args: GeneratorArgs):
|
def __init__(self, config: InlineImplConfig):
|
||||||
self.args = args
|
self.config = config
|
||||||
|
|
||||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||||
# while the tool-use loop is going
|
# while the tool-use loop is going
|
||||||
self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path))
|
checkpoint = self.config.checkpoint_config.checkpoint
|
||||||
|
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
self.__enter__()
|
self.__enter__()
|
||||||
|
@@ -70,9 +66,10 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)
 
     def __enter__(self):
+        checkpoint = self.config.checkpoint_config.checkpoint
         self.group = ModelParallelProcessGroup(
-            self.args.model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.args),
+            checkpoint.model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
         return self

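Sketch of the intended callsite after this change (not part of the commit): the generator now takes the whole InlineImplConfig and, as the class docstring notes, is a context manager, so the model-parallel process group is started and torn down deterministically. `config` is an InlineImplConfig like the one sketched earlier; the import path is assumed.

    from toolchain.inference.model_parallel import LlamaModelParallelGenerator  # path assumed

    with LlamaModelParallelGenerator(config) as generator:
        # issue generation requests against `generator` here
        ...
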
@@ -2,7 +2,6 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import collections
-from enum import Enum, unique
 from typing import Optional, Type
 
 try:

@@ -11,20 +10,12 @@ try:
     print("Using efficient FP8 operators in FBGEMM.")
 except (ImportError, ModuleNotFoundError):
     print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
+    raise
 
 import torch
 from torch import nn, Tensor
 
 
-@unique
-class FfnQuantizeMode(Enum):
-    FP8_ROWWISE = "fp8_rowwise"
-    NONE = "none"
-
-    def __str__(self) -> str:
-        return self.value
-
-
 class Fp8ScaledWeights:
     # TODO: Ugly trick so torch allows us to replace parameters
     # with our custom Fp8Weights instance. Do this properly.

@@ -84,7 +75,6 @@ def ffn_swiglu(
 def quantize_fp8(
     w: Tensor,
     fp8_activation_scale_ub: float,
-    mode: Optional[FfnQuantizeMode] = None,
     output_device: Optional[torch.device] = None,
 ) -> Fp8RowwiseWeights:
     """Quantize [n, k] weight tensor.

@@ -92,22 +82,45 @@ def quantize_fp8(
     Args:
         w (Tensor): [n, k] input high precision tensor to quantize.
         fp8_activation_scale_ub (float): Upper bound for activation max.
-        mode (FfnQuantizeMode): Quantization mode.
     """
     activation_scale_ub = torch.tensor(
         [fp8_activation_scale_ub],
         dtype=torch.float,
         device="cuda",
     )
-    if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE:  # rowwise
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-        del w
-        return Fp8RowwiseWeights(
-            weight=wq,
-            scale=w_scale,
-            shape=wq.shape,
-            activation_scale_ub=activation_scale_ub,
-        )
+    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+    del w
+    return Fp8RowwiseWeights(
+        weight=wq,
+        scale=w_scale,
+        shape=wq.shape,
+        activation_scale_ub=activation_scale_ub,
+    )
+
+
+@torch.inference_mode()
+def load_fp8(
+    w: Tensor,
+    w_scale: Tensor,
+    fp8_activation_scale_ub: float,
+) -> Fp8RowwiseWeights:
+    """Load FP8 [n, k] weight tensor.
+
+    Args:
+        w (Tensor): [n, k] input FP8.
+        fp8_activation_scale_ub (float): Upper bound for activation max.
+    """
+    activation_scale_ub = torch.tensor(
+        [fp8_activation_scale_ub],
+        dtype=torch.float,
+        device="cuda",
+    )
+    return Fp8RowwiseWeights(
+        weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
+        scale=w_scale.to(device="cuda"),
+        shape=w.shape,
+        activation_scale_ub=activation_scale_ub,
+    )
 
 
 def fc_fp8_dynamic(

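A small usage sketch of the two entry points above (not part of the commit; requires a CUDA device and fbgemm-gpu, since quantize_fp8 calls torch.ops.fbgemm.quantize_fp8_per_row, and the import path is assumed): row-wise quantize a bf16 weight, then fake-dequantize it the same way the unit test below does to inspect the error.

    import torch

    from toolchain.inference.quantization.fp8_impls import quantize_fp8  # path assumed

    w = torch.randn(512, 256, dtype=torch.bfloat16, device="cuda") * 0.01
    wq = quantize_fp8(w.clone(), fp8_activation_scale_ub=1200.0)

    # one scale per output row, hence the unsqueeze on the last dim
    w_approx = wq.weight.bfloat16() * wq.scale.unsqueeze(-1)
    print((w_approx - w).abs().max())
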
toolchain/inference/quantization/loader.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+
+import os
+
+from typing import Optional
+
+import torch
+
+from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
+from models.llama3_1.api.model import Transformer, TransformerBlock
+
+from toolchain.inference.api.config import (
+    CheckpointQuantizationFormat,
+    InlineImplConfig,
+)
+from toolchain.inference.api.datatypes import (
+    QuantizationType,
+)
+
+from termcolor import cprint
+
+
+def is_fbgemm_available() -> bool:
+    try:
+        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+
+        return True
+    except (ImportError, ModuleNotFoundError):
+        return False
+
+
+def convert_to_quantized_model(
+    model: Transformer,
+    config: InlineImplConfig,
+    fp8_activation_scale_ub: Optional[float] = 1200.0,
+) -> Transformer:
+    if config.quantization.type == QuantizationType.bf16.value:
+        return model
+
+    elif config.quantization.type != QuantizationType.fp8.value:
+        raise ValueError("Only FP8 quantization is supported")
+
+    from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
+
+    checkpoint = config.checkpoint_config.checkpoint
+    # Move weights to GPU with quantization
+    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
+        cprint("Loading fp8 scales...", "yellow")
+        fp8_scales_path = os.path.join(
+            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
+        )
+        assert os.path.isfile(
+            fp8_scales_path
+        ), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+        fp8_scales = torch.load(fp8_scales_path, weights_only=True)
+
+        for block in model.layers:
+            if isinstance(block, TransformerBlock):
+                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
+                    continue
+                block.feed_forward.w1.weight = load_fp8(
+                    block.feed_forward.w1.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+                block.feed_forward.w3.weight = load_fp8(
+                    block.feed_forward.w3.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+                block.feed_forward.w2.weight = load_fp8(
+                    block.feed_forward.w2.weight,
+                    fp8_scales[
+                        f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
+                    ],
+                    fp8_activation_scale_ub,
+                )
+    else:
+        cprint("Quantizing fp8 weights from bf16...", "yellow")
+        for block in model.layers:
+            if isinstance(block, TransformerBlock):
+                if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
+                    continue
+                block.feed_forward.w1.weight = quantize_fp8(
+                    block.feed_forward.w1.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+                block.feed_forward.w3.weight = quantize_fp8(
+                    block.feed_forward.w3.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+                block.feed_forward.w2.weight = quantize_fp8(
+                    block.feed_forward.w2.weight,
+                    fp8_activation_scale_ub,
+                    output_device=torch.device("cuda"),
+                )
+
+    for _, parameter in model.named_parameters():
+        if not isinstance(parameter, Fp8ScaledWeights):
+            parameter.data = parameter.to(device="cuda")
+    return model

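The fp8_mixed branch above expects a per-rank side-car file fp8_scales_{rank}.pt whose keys follow f"{layer_id}_feed_forward.{w1|w3|w2}_{rank}". This diff does not show the script that produces those files; below is a hypothetical sketch of how one could be generated offline from a bf16 model (the helper name and import paths are assumptions, and it needs CUDA plus fbgemm-gpu).

    import os

    import torch

    from models.llama3_1.api.model import TransformerBlock
    from toolchain.inference.quantization.fp8_impls import quantize_fp8  # path assumed


    def save_fp8_scales(model, checkpoint_dir: str, rank: int, ub: float = 1200.0) -> None:
        # hypothetical helper: quantize the FFN weights of the middle layers and persist
        # only their per-row scales, mirroring what convert_to_quantized_model reads back
        fp8_scales = {}
        for block in model.layers:
            if not isinstance(block, TransformerBlock):
                continue
            # first and last layers stay un-quantized, as in convert_to_quantized_model
            if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                continue
            for name in ("w1", "w3", "w2"):
                w = getattr(block.feed_forward, name).weight
                wq = quantize_fp8(w, ub, output_device=torch.device("cuda"))
                fp8_scales[f"{block.layer_id}_feed_forward.{name}_{rank}"] = wq.scale

        torch.save(fp8_scales, os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt"))
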
@@ -18,6 +18,12 @@ from fp8.fp8_impls import ffn_swiglu
 from torch import nn
 
 
+@dataclass
+class QuantizationArgs:
+    fp8_rowwise: bool = False
+    convert_from_bf16: bool = False
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096

@@ -31,6 +37,8 @@ class ModelArgs:
     rope_theta: float = 500000
     use_scaled_rope: bool = False
 
+    quantization: Optional[QuantizationArgs] = None
+
     max_batch_size: int = 32
     max_seq_len: int = 2048

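For completeness, a minimal sketch of constructing ModelArgs with the new field (field names are from the diff; the import path is an assumption since this file's location isn't shown):

    from fp8.model import ModelArgs, QuantizationArgs  # import path assumed

    args = ModelArgs(
        dim=4096,
        max_seq_len=2048,
        max_batch_size=1,
        quantization=QuantizationArgs(fp8_rowwise=True, convert_from_bf16=True),
    )
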
@@ -5,7 +5,7 @@ import unittest
 
 import torch
 
-from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
+from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
 from hypothesis import given, settings, strategies as st
 from torch import Tensor

@@ -33,70 +33,42 @@ class FP8Tests(unittest.TestCase):
         UB: float,
     ) -> None:
         x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        w13 = (
-            torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        w1 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        )
+        w3 = (
+            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
         )
         w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
 
-        x_q = quantize_fp8(x, UB)
-        w13_q = quantize_fp8(w13, UB)
-        w2_q = quantize_fp8(w2, UB)
+        x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
 
-        def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
+        def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
             (B, T, D) = x.shape
-            (HD_L_2, D_) = w13.shape
+            (HD_L, D_) = w1.shape
             assert D_ == D
-            HD_L = HD_L_2 // 2
 
-            y = x.view(B * T, D) @ w13.T
-            x1 = y[:, :HD_L]
-            x2 = y[:, HD_L:]
+            x1 = x.view(B * T, D) @ w1.T
+            x2 = x.view(B * T, D) @ w3.T
 
             z = torch.nn.functional.silu(x1) * x2
             return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
 
-        v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
+        v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
 
         # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        w13 = w13_q.weight.bfloat16() * w13_q.scale
-        w2 = w2_q.weight.bfloat16() * w2_q.scale
+        x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
+        w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
+        w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
+        w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
 
-        v_ref = ref_ffn(x, w13, w2)
+        v_ref = ref_ffn(x, w1, w3, w2)
 
         torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
 
-    @settings(deadline=None)
-    @given(
-        B_T=st.sampled_from([2048, 4096]),
-        D=st.sampled_from([128, 256]),
-        HD_L=st.sampled_from([256, 512]),
-        UB=st.sampled_from([1000, 10000]),
-    )
-    def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
-        B_T = 4096
-        D = 256
-        HD_L = 512
-        UB = float(UB)
-        x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-
-        x_q = quantize_fp8(x, UB)
-        wqkv_q = quantize_fp8(wqkv, UB)
-
-        num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
-
-        y = attn_linear(x, wqkv_q)
-        y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
-
-        # Fake quant
-        x = x_q.weight.bfloat16() * x_q.scale
-        wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
-        y_ref = (x @ wqkv.T).to(torch.bfloat16)
-
-        torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
-        torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)
-
 
 if __name__ == "__main__":
     unittest.main()