fp8 inference

Ashwin Bharambe 2024-07-20 23:13:47 -07:00
parent ad62e2e1f3
commit 0746a0f62b
2 changed files with 23 additions and 9 deletions


@@ -5,12 +5,13 @@ from typing import AsyncGenerator
import fire
import httpx
from .api.endpoints import (
from .api import (
    ChatCompletionRequest,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    InstructModel,
    ModelInference,
    UserMessage,
)
@@ -57,7 +58,7 @@ async def run_main(host: str, port: int):
    )
    async for event in client.chat_completion(
        ChatCompletionRequest(
            model=InstructModel.llama3_70b_chat,
            model=InstructModel.llama3_8b_chat,
            messages=[message],
            stream=True,
        )
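Not part of the commit, just for orientation: after this hunk the demo client pulls its types from the consolidated .api module and defaults to the 8B instruct model. A minimal sketch of the streaming call as it would look inside that client module, assuming only the names visible above (the UserMessage field name and the event handling are illustrative):

from .api import ChatCompletionRequest, InstructModel, UserMessage

async def demo(client) -> None:
    # The content field name on UserMessage is assumed for illustration.
    message = UserMessage(content="Hello, Llama!")
    async for event in client.chat_completion(
        ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,  # was llama3_70b_chat before this commit
            messages=[message],
            stream=True,
        )
    ):
        # Presumably each event is a ChatCompletionResponseStreamChunk from the server.
        print(event)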


@@ -2,30 +2,42 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from typing import Optional
import torch
from torch import Tensor
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from models.llama3_1.api.model import Transformer, TransformerBlock
from termcolor import cprint
from toolchain.inference.api.config import (
    CheckpointQuantizationFormat,
    InlineImplConfig,
)
from toolchain.inference.api.datatypes import (
    QuantizationType,
)
from toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint
def is_fbgemm_available() -> bool:
    try:
        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
        return True
    except (ImportError, ModuleNotFoundError):
        return False
def swiglu_wrapper(
    self,
    x: Tensor,
):
    from .fp8_impls import ffn_swiglu
    out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
    return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
    model: Transformer,
    config: InlineImplConfig,
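Two details in this hunk are worth spelling out. swiglu_wrapper routes the block's feed-forward through the fp8 ffn_swiglu kernel and then calls reduce_from_model_parallel_region, since w1/w3/w2 are sharded across model-parallel ranks and the partial results must be summed. And the patching later in the diff uses swiglu_wrapper.__get__(block.feed_forward): plain functions are descriptors, so __get__ binds one to an existing instance as a method. A self-contained sketch of that binding trick with a toy module (illustrative names, not the toolchain's classes):

import torch
import torch.nn.functional as F
from torch import nn

class ToyFeedForward(nn.Module):
    """Toy stand-in for the block's feed-forward module."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        # Reference SwiGLU: silu(w1(x)) * w3(x), projected back through w2.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def patched_forward(self, x):
    # Stand-in for swiglu_wrapper: same math, different code path
    # (plain matmuls here instead of the fused fp8 kernel).
    out = F.silu(x @ self.w1.weight.t()) * (x @ self.w3.weight.t())
    return out @ self.w2.weight.t()

ffn = ToyFeedForward(dim=16, hidden=32)
# Functions are descriptors: __get__(instance) returns a bound method,
# so the existing module instance now runs the patched forward.
ffn.forward = patched_forward.__get__(ffn)

x = torch.randn(2, 16)
print(ffn(x).shape)  # torch.Size([2, 16])

Binding with __get__ keeps self resolution identical to a normal method, which is why a single module-level function can be reused across every transformer block.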
@@ -39,8 +51,6 @@ def convert_to_quantized_model(
    from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
    checkpoint = config.checkpoint_config.checkpoint
    # Move weights to GPU with quantization
    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
@@ -57,6 +67,8 @@ def convert_to_quantized_model(
        if isinstance(block, TransformerBlock):
            if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                continue
            block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
            block.feed_forward.w1.weight = load_fp8(
                block.feed_forward.w1.weight,
                fp8_scales[
@@ -84,6 +96,7 @@ def convert_to_quantized_model(
        if isinstance(block, TransformerBlock):
            if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                continue
            block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
            block.feed_forward.w1.weight = quantize_fp8(
                block.feed_forward.w1.weight,
                fp8_activation_scale_ub,
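The two branches differ only in where the fp8 scales come from: load_fp8 applies per-layer scales precomputed for an fp8_mixed checkpoint, while quantize_fp8 derives them from the bf16 weights at load time (fp8_activation_scale_ub appears to cap the activation scales used later at run time). Both branches skip the first and last blocks, presumably to protect accuracy in the most sensitive layers. The real kernels live in fp8_impls and are not shown in this diff; as a rough mental model, rowwise fp8 weight quantization looks like the hedged sketch below (function name and return layout are assumptions, and it needs a PyTorch build with torch.float8_e4m3fn):

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def rowwise_quantize_fp8_sketch(w: torch.Tensor):
    """Hedged sketch of rowwise fp8 quantization (not the real quantize_fp8).

    Returns an e4m3 weight plus one float32 scale per output row; a matmul
    kernel would multiply its accumulator by these scales to restore the
    original magnitude.
    """
    # Pick one scale per output row so the row's abs-max lands on the fp8 max.
    row_amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = row_amax / FP8_E4M3_MAX
    w_fp8 = (w / scale).to(torch.float8_e4m3fn)
    return w_fp8, scale.squeeze(1)

w = torch.randn(1024, 1024)
w_fp8, scales = rowwise_quantize_fp8_sketch(w)
# Dequantize and check that the rounding error stays small.
w_hat = w_fp8.to(torch.float32) * scales.unsqueeze(1)
print((w - w_hat).abs().max())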