diff --git a/toolchain/inference/client.py b/toolchain/inference/client.py
index 0cb14e4c7..a4d2b641f 100644
--- a/toolchain/inference/client.py
+++ b/toolchain/inference/client.py
@@ -5,12 +5,13 @@ from typing import AsyncGenerator
 
 import fire
 import httpx
-from .api.endpoints import (
+from .api import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     InstructModel,
     ModelInference,
+    UserMessage,
 )
 
 
@@ -57,7 +58,7 @@ async def run_main(host: str, port: int):
     )
     async for event in client.chat_completion(
         ChatCompletionRequest(
-            model=InstructModel.llama3_70b_chat,
+            model=InstructModel.llama3_8b_chat,
             messages=[message],
             stream=True,
         )
diff --git a/toolchain/inference/quantization/loader.py b/toolchain/inference/quantization/loader.py
index fde77685f..f2a162b40 100644
--- a/toolchain/inference/quantization/loader.py
+++ b/toolchain/inference/quantization/loader.py
@@ -2,30 +2,42 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import os
-
 from typing import Optional
 
 import torch
+from torch import Tensor
+
+from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 
 from models.llama3_1.api.model import Transformer, TransformerBlock
+from termcolor import cprint
+
 from toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
     InlineImplConfig,
 )
-from toolchain.inference.api.datatypes import (
-    QuantizationType,
-)
+from toolchain.inference.api.datatypes import QuantizationType
 
-from termcolor import cprint
 
 def is_fbgemm_available() -> bool:
     try:
         import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+
         return True
     except (ImportError, ModuleNotFoundError):
         return False
 
 
+def swiglu_wrapper(
+    self,
+    x: Tensor,
+):
+    from .fp8_impls import ffn_swiglu
+
+    out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
+    return reduce_from_model_parallel_region(out)
+
+
 def convert_to_quantized_model(
     model: Transformer,
     config: InlineImplConfig,
@@ -39,8 +51,6 @@ def convert_to_quantized_model(
 
 
     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
-
-
     checkpoint = config.checkpoint_config.checkpoint
     # Move weights to GPU with quantization
     if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
@@ -57,6 +67,8 @@ def convert_to_quantized_model(
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue
+
+                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
                 block.feed_forward.w1.weight = load_fp8(
                     block.feed_forward.w1.weight,
                     fp8_scales[
@@ -84,6 +96,7 @@ def convert_to_quantized_model(
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue
+                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
                 block.feed_forward.w1.weight = quantize_fp8(
                     block.feed_forward.w1.weight,
                     fp8_activation_scale_ub,
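
Note (illustration, not part of the patch): the loader change replaces each non-boundary TransformerBlock's feed-forward `forward` with `swiglu_wrapper` via `swiglu_wrapper.__get__(block.feed_forward)`, i.e. Python's descriptor protocol is used to bind a plain function to an existing instance so it receives that instance as `self`. The self-contained sketch below demonstrates only that binding trick; the `FeedForward` class, the dimensions, and the plain-precision stand-in for the fp8 `ffn_swiglu` kernel are hypothetical and exist purely for illustration.

# Minimal sketch of the __get__ binding used in the patch (hypothetical names and shapes).
import torch
from torch import Tensor, nn


class FeedForward(nn.Module):
    """Stand-in for a SwiGLU feed-forward block: w1/w3 project up, w2 projects down."""

    def __init__(self, dim: int, hidden: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


def swiglu_wrapper(self, x: Tensor) -> Tensor:
    # Plain-precision stand-in for the fp8 ffn_swiglu kernel: same SwiGLU math,
    # written directly against the weight tensors, as the patched wrapper does.
    out = nn.functional.silu(x @ self.w1.weight.t()) * (x @ self.w3.weight.t())
    return out @ self.w2.weight.t()


ffn = FeedForward(dim=16, hidden=64)
x = torch.randn(2, 16)

# function.__get__(instance) returns a bound method, so the wrapper receives
# self=ffn and can read ffn.w1 / ffn.w2 / ffn.w3 -- the loader does this per block.
ffn.forward = swiglu_wrapper.__get__(ffn)

# The rebinding preserves the original computation (up to float rounding).
assert torch.allclose(ffn.forward(x), FeedForward.forward(ffn, x), atol=1e-5)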