Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-28 15:02:37 +00:00
fp8 inference

This commit is contained in: parent ad62e2e1f3, commit 0746a0f62b

2 changed files with 23 additions and 9 deletions
@@ -5,12 +5,13 @@ from typing import AsyncGenerator

 import fire
 import httpx

-from .api.endpoints import (
+from .api import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     InstructModel,
     ModelInference,
     UserMessage,
 )

@@ -57,7 +58,7 @@ async def run_main(host: str, port: int):
     )
     async for event in client.chat_completion(
         ChatCompletionRequest(
-            model=InstructModel.llama3_70b_chat,
+            model=InstructModel.llama3_8b_chat,
             messages=[message],
             stream=True,
         )

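For context, a streaming loop like the one in this hunk is normally driven by a small fire-based entrypoint that wraps an asyncio call. The sketch below only mirrors that pattern and is not part of the commit: StubClient is a toy stand-in for whatever client object the surrounding file builds from host and port, and the request payload is simplified.

    import asyncio

    import fire


    class StubClient:
        # Toy stand-in for the real client: yields a few chunks so the
        # streaming loop below has something to iterate over.
        async def chat_completion(self, request):
            for chunk in ("hel", "lo ", "world"):
                await asyncio.sleep(0)
                yield chunk


    async def run_main(host: str, port: int):
        client = StubClient()  # the real code builds a client from host/port
        async for event in client.chat_completion(request={"stream": True}):
            print(event, end="", flush=True)
        print()


    def main(host: str = "localhost", port: int = 5000):
        asyncio.run(run_main(host, port))


    if __name__ == "__main__":
        fire.Fire(main)
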
@@ -2,30 +2,42 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

 import os

 from typing import Optional

 import torch
+from torch import Tensor
+
+from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from models.llama3_1.api.model import Transformer, TransformerBlock

+from termcolor import cprint
+
 from toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
     InlineImplConfig,
 )
-from toolchain.inference.api.datatypes import (
-    QuantizationType,
-)
+from toolchain.inference.api.datatypes import QuantizationType

-from termcolor import cprint

+
+def is_fbgemm_available() -> bool:
+    try:
+        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+
+        return True
+    except (ImportError, ModuleNotFoundError):
+        return False
+
+
+def swiglu_wrapper(
+    self,
+    x: Tensor,
+):
+    from .fp8_impls import ffn_swiglu
+
+    out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
+    return reduce_from_model_parallel_region(out)
+
+
 def convert_to_quantized_model(
     model: Transformer,
     config: InlineImplConfig,

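is_fbgemm_available() in the hunk above probes for fbgemm-gpu's gen_ai kernels simply by attempting the import. A hedged sketch of how such a probe is typically used to fail fast before taking the fp8 path; the guard function and its error text are assumptions, not part of this commit.

    def ensure_fp8_kernels() -> None:
        # Hypothetical guard: refuse the fp8 path when the fbgemm gen_ai
        # extension cannot be imported on this machine.
        if not is_fbgemm_available():
            raise RuntimeError(
                "fp8 inference needs fbgemm_gpu.experimental.gen_ai; "
                "install fbgemm-gpu built with GPU support or disable fp8."
            )
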
@@ -39,8 +51,6 @@ def convert_to_quantized_model
     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8

     checkpoint = config.checkpoint_config.checkpoint
     # Move weights to GPU with quantization
     if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:

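The gate above compares the checkpoint's quantization_format string against CheckpointQuantizationFormat.fp8_mixed.value. A small self-contained illustration of that enum-value comparison pattern; the bf16 member and both value strings are assumptions made only for this sketch.

    from enum import Enum


    class CheckpointQuantizationFormat(Enum):
        # Member names follow the diff; the value strings here are assumed.
        bf16 = "bf16"
        fp8_mixed = "fp8_mixed"


    def wants_fp8(quantization_format: str) -> bool:
        # The diff compares the raw string from the checkpoint config against
        # the enum member's .value, exactly as done in the hunk above.
        return quantization_format == CheckpointQuantizationFormat.fp8_mixed.value


    assert wants_fp8("fp8_mixed") is True
    assert wants_fp8("bf16") is False
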
@@ -57,6 +67,8 @@ def convert_to_quantized_model
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue
+
+                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
                 block.feed_forward.w1.weight = load_fp8(
                     block.feed_forward.w1.weight,
                     fp8_scales[

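swiglu_wrapper.__get__(block.feed_forward) uses the descriptor protocol to bind a plain function to one specific module instance, so only that block's forward is swapped while the class and the other layers are untouched. A self-contained illustration of the trick, unrelated to the actual model classes:

    class Layer:
        def forward(self):
            return "original"


    def patched_forward(self):
        # `self` is the instance the function gets bound to below.
        return "patched"


    a, b = Layer(), Layer()
    # Bind patched_forward to `a` only; this mirrors
    # swiglu_wrapper.__get__(block.feed_forward) in the hunk above.
    a.forward = patched_forward.__get__(a)

    assert a.forward() == "patched"
    assert b.forward() == "original"
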
@@ -84,6 +96,7 @@ def convert_to_quantized_model
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue
+                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
                 block.feed_forward.w1.weight = quantize_fp8(
                     block.feed_forward.w1.weight,
                     fp8_activation_scale_ub,

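quantize_fp8 and load_fp8 live in the fp8_impls module this commit depends on, and their exact signatures are not shown in the diff. As a rough sketch of the underlying technique only, rowwise fp8 weight quantization with a carried activation-scale upper bound, in plain PyTorch (assumes torch >= 2.1 for float8_e4m3fn); this is an illustration, not the implementation used by the commit.

    from typing import Optional

    import torch


    def quantize_fp8_rowwise(
        w: torch.Tensor, activation_scale_ub: Optional[float] = None
    ):
        # One scale per output row: map each row's max |value| onto the
        # representable range of float8 e4m3.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        row_amax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
        scale = row_amax / fp8_max
        w_fp8 = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        # The activation-scale upper bound is kept alongside the weights so a
        # matmul kernel can clamp activation scales at run time, which is the
        # role fp8_activation_scale_ub plays in the hunks above.
        return w_fp8, scale.float(), activation_scale_ub


    # Tiny usage example on random weights; the upper-bound value is arbitrary.
    w = torch.randn(8, 16)
    w_fp8, scales, ub = quantize_fp8_rowwise(w, activation_scale_ub=1200.0)
    print(w_fp8.dtype, scales.shape, ub)
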