fp8 inference

This commit is contained in:
Ashwin Bharambe 2024-07-20 23:13:47 -07:00
parent ad62e2e1f3
commit 0746a0f62b
2 changed files with 23 additions and 9 deletions

View file

@@ -5,12 +5,13 @@ from typing import AsyncGenerator
 import fire
 import httpx
-from .api.endpoints import (
+from .api import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     InstructModel,
     ModelInference,
+    UserMessage,
 )
@@ -57,7 +58,7 @@ async def run_main(host: str, port: int):
     )
     async for event in client.chat_completion(
         ChatCompletionRequest(
-            model=InstructModel.llama3_70b_chat,
+            model=InstructModel.llama3_8b_chat,
             messages=[message],
             stream=True,
         )
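
For context, a minimal sketch of how the updated client-side call fits together after this change. Only the identifiers visible in the diff (ChatCompletionRequest, InstructModel.llama3_8b_chat, UserMessage, chat_completion, stream=True) come from the source; the UserMessage constructor argument, the client object, and the event handling are assumptions.

from .api import (
    ChatCompletionRequest,
    InstructModel,
    UserMessage,
)


async def send_one_message(client) -> None:
    # `client` is assumed to be the ModelInference client created in run_main.
    message = UserMessage(content="Write me a two sentence poem")  # field name assumed
    async for event in client.chat_completion(
        ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,  # switched from llama3_70b_chat in this commit
            messages=[message],
            stream=True,
        )
    ):
        print(event)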

View file

@@ -2,30 +2,42 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import os
 from typing import Optional
 import torch
+from torch import Tensor
+from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from models.llama3_1.api.model import Transformer, TransformerBlock
+from termcolor import cprint
 from toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
     InlineImplConfig,
 )
-from toolchain.inference.api.datatypes import (
-    QuantizationType,
-)
+from toolchain.inference.api.datatypes import QuantizationType
-from termcolor import cprint
 def is_fbgemm_available() -> bool:
     try:
         import fbgemm_gpu.experimental.gen_ai  # noqa: F401
         return True
     except (ImportError, ModuleNotFoundError):
         return False
+def swiglu_wrapper(
+    self,
+    x: Tensor,
+):
+    from .fp8_impls import ffn_swiglu
+
+    out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
+    return reduce_from_model_parallel_region(out)
+
+
 def convert_to_quantized_model(
     model: Transformer,
     config: InlineImplConfig,
@@ -39,8 +51,6 @@ def convert_to_quantized_model(
     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
     checkpoint = config.checkpoint_config.checkpoint
     # Move weights to GPU with quantization
     if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
@@ -57,6 +67,8 @@ def convert_to_quantized_model(
         if isinstance(block, TransformerBlock):
             if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                 continue
+            block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
             block.feed_forward.w1.weight = load_fp8(
                 block.feed_forward.w1.weight,
                 fp8_scales[
@@ -84,6 +96,7 @@ def convert_to_quantized_model(
         if isinstance(block, TransformerBlock):
             if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                 continue
+            block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
             block.feed_forward.w1.weight = quantize_fp8(
                 block.feed_forward.w1.weight,
                 fp8_activation_scale_ub,
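
The `swiglu_wrapper.__get__(block.feed_forward)` lines above lean on Python's descriptor protocol: calling `__get__` on a plain function returns a method bound to the given object, so each block's `feed_forward.forward` can be redirected to the FP8 SwiGLU path without subclassing the module. A standalone sketch of that binding trick; the class and function names here are illustrative, not from the source.

class FeedForward:
    def forward(self, x: str) -> str:
        return f"original({x})"


def patched_forward(self, x: str) -> str:
    # `self` is whatever instance the free function was bound to via __get__,
    # mirroring how swiglu_wrapper reads self.w1/self.w2/self.w3 above.
    return f"patched({x})"


ffn = FeedForward()
ffn.forward = patched_forward.__get__(ffn)  # bind the free function to this instance only
print(ffn.forward("x"))            # -> patched(x)
print(FeedForward().forward("x"))  # other instances keep the original -> original(x)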