mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

fp8 inference

This commit is contained in:
parent ad62e2e1f3
commit 0746a0f62b

2 changed files with 23 additions and 9 deletions
@@ -5,12 +5,13 @@ from typing import AsyncGenerator
 
 import fire
 import httpx
 
-from .api.endpoints import (
+from .api import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     InstructModel,
     ModelInference,
+    UserMessage,
 )
 
@@ -57,7 +58,7 @@ async def run_main(host: str, port: int):
     )
     async for event in client.chat_completion(
         ChatCompletionRequest(
-            model=InstructModel.llama3_70b_chat,
+            model=InstructModel.llama3_8b_chat,
             messages=[message],
             stream=True,
         )
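The client-side change only switches to the consolidated .api import and to the llama3_8b_chat default model. For orientation, here is a minimal sketch of how an async streaming loop like the one above is typically driven with fire and asyncio; the stream itself is a hypothetical stand-in, not the repo's client API.

import asyncio

import fire


async def fake_stream():
    # Hypothetical stand-in for client.chat_completion(...); yields a few chunks.
    for chunk in ("hello", " ", "world"):
        yield chunk


async def run_main(host: str, port: int):
    # In the real client this is where ChatCompletionRequest(stream=True)
    # is sent and the streamed response chunks are consumed.
    async for event in fake_stream():
        print(event, end="", flush=True)


def main(host: str = "localhost", port: int = 5000):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)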
@@ -2,30 +2,42 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import os
 
 from typing import Optional
 
 import torch
+from torch import Tensor
 
+from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from models.llama3_1.api.model import Transformer, TransformerBlock
 
+from termcolor import cprint
 
 from toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
     InlineImplConfig,
 )
-from toolchain.inference.api.datatypes import (
-    QuantizationType,
-)
+from toolchain.inference.api.datatypes import QuantizationType
 
-from termcolor import cprint
 
 def is_fbgemm_available() -> bool:
     try:
         import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
         return True
     except (ImportError, ModuleNotFoundError):
         return False
 
 
+def swiglu_wrapper(
+    self,
+    x: Tensor,
+):
+    from .fp8_impls import ffn_swiglu
+
+    out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
+    return reduce_from_model_parallel_region(out)
+
+
 def convert_to_quantized_model(
     model: Transformer,
     config: InlineImplConfig,
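The new swiglu_wrapper routes the block's feed-forward through the fp8 ffn_swiglu kernel (presumably backed by fbgemm, given the is_fbgemm_available check) and then all-reduces the result across model-parallel ranks. For orientation, a minimal unquantized sketch of the SwiGLU feed-forward that such a kernel replaces, assuming nn.Linear-style weight layouts ([out_features, in_features]); this is an illustration, not the repo's fp8_impls code.

import torch.nn.functional as F
from torch import Tensor


def ffn_swiglu_reference(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
    # SwiGLU feed-forward as used in Llama-style blocks:
    #   w2( silu(x @ w1^T) * (x @ w3^T) )
    # An fp8 implementation performs the same matmuls on quantized
    # copies of w1/w3/w2 with per-row scales.
    gate = F.silu(x @ w1.t())     # [tokens, dim] -> [tokens, ffn_dim]
    up = x @ w3.t()               # [tokens, dim] -> [tokens, ffn_dim]
    return (gate * up) @ w2.t()   # [tokens, ffn_dim] -> [tokens, dim]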
@@ -39,8 +51,6 @@ def convert_to_quantized_model(
     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
 
 
-
-
     checkpoint = config.checkpoint_config.checkpoint
     # Move weights to GPU with quantization
     if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
@@ -57,6 +67,8 @@ def convert_to_quantized_model(
         if isinstance(block, TransformerBlock):
             if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                 continue
+
+            block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
             block.feed_forward.w1.weight = load_fp8(
                 block.feed_forward.w1.weight,
                 fp8_scales[
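Both quantization branches patch each block's feed-forward by binding the module-level swiglu_wrapper to the existing feed-forward instance through the descriptor protocol: function.__get__(obj) returns a bound method, so self inside swiglu_wrapper is that block's feed-forward module. A small self-contained sketch of the same trick; the toy class and names are illustrative only.

class Box:
    def __init__(self, value):
        self.value = value


def doubled(self):
    # A plain module-level function; `self` is supplied by the binding below.
    return self.value * 2


box = Box(21)
# function.__get__(instance) produces a bound method, just like
# swiglu_wrapper.__get__(block.feed_forward) in the diff above.
box.describe = doubled.__get__(box)
print(box.describe())  # 42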
@@ -84,6 +96,7 @@ def convert_to_quantized_model(
         if isinstance(block, TransformerBlock):
             if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                 continue
+            block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
             block.feed_forward.w1.weight = quantize_fp8(
                 block.feed_forward.w1.weight,
                 fp8_activation_scale_ub,
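In this second branch the weights are quantized at load time (quantize_fp8) instead of being read from pre-computed fp8 scales in the checkpoint (load_fp8), and the first and last transformer blocks are skipped in both cases. As a rough illustration of what rowwise fp8 weight quantization involves; this is not the repo's quantize_fp8, and the activation-scale upper-bound handling is simplified away.

import torch


def quantize_fp8_rowwise_sketch(w: torch.Tensor):
    # Per output row, choose a scale so the largest magnitude maps to the
    # fp8 e4m3 maximum, then cast. Production kernels additionally clamp
    # activation scales with an upper bound such as fp8_activation_scale_ub.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    row_absmax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = row_absmax / fp8_max
    w_fp8 = (w / scale).to(torch.float8_e4m3fn)
    return w_fp8, scale.squeeze(1)


# Rough inverse for checking: w ≈ w_fp8.to(torch.float32) * scale[:, None]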