forked from phoenix-oss/llama-stack-mirror
New quantized models (#301)
This commit is contained in:
parent
05a8d47b98
commit
7afe51c84d
6 changed files with 292 additions and 21 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -13,6 +13,7 @@ xcuserdata/
|
|||
Package.resolved
|
||||
*.pte
|
||||
*.ipynb_checkpoints*
|
||||
.idea
|
||||
.venv/
|
||||
.idea
|
||||
_build
|
||||
|
|
|
@ -172,7 +172,7 @@ async def run_mm_main(
|
|||
],
|
||||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
iterator = client.chat_completion(
|
||||
iterator = await client.chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
|
|
|
@ -25,6 +25,7 @@ class LogProbConfig(BaseModel):
|
|||
class QuantizationType(Enum):
|
||||
bf16 = "bf16"
|
||||
fp8 = "fp8"
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -37,8 +38,14 @@ class Bf16QuantizationConfig(BaseModel):
|
|||
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Int4QuantizationConfig(BaseModel):
|
||||
type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
|
||||
scheme: Optional[str] = None
|
||||
|
||||
|
||||
QuantizationConfig = Annotated[
|
||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
|
||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
@ -219,8 +226,6 @@ class Inference(Protocol):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
||||
|
||||
# This method is not `async def` because it can result in either an
|
||||
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
|
||||
@webmethod(route="/inference/chat_completion")
|
||||
async def chat_completion(
|
||||
self,
|
||||
|
|
|
@ -30,7 +30,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
|
|||
CrossAttentionTransformer,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
|
@ -43,7 +42,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
MetaReferenceInferenceConfig,
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
|
@ -131,18 +135,34 @@ class Llama:
|
|||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from .quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from .quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
assert (
|
||||
config.quantization.scheme is not None
|
||||
), "Please specify a quantization scheme."
|
||||
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_quantized_model(model, config, ckpt_dir)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently int4 and fp8 are the only supported quantization methods."
|
||||
)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
|
|
@ -8,19 +8,25 @@
|
|||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
from torch import nn, Tensor
|
||||
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
from llama_stack.apis.inference.inference import Int4QuantizationConfig
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
|
@ -37,7 +43,7 @@ def swiglu_wrapper(
|
|||
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
def convert_to_fp8_quantized_model(
|
||||
model: Transformer,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
checkpoint_dir: str,
|
||||
|
@ -99,3 +105,241 @@ def convert_to_quantized_model(
|
|||
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
return model
|
||||
|
||||
|
||||
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
||||
"""
|
||||
Int8DynActInt4WeightLinear with LoRA adaptor.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
device: Device to use.
|
||||
group_size: Group size for quantization.
|
||||
precision: Precision of quantization.
|
||||
scales_precision: Precision of scales.
|
||||
lora_rank: Rank of LoRA adaptor.
|
||||
lora_scale: Scale of LoRA adaptor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias=False,
|
||||
device=None,
|
||||
# quantization parameters
|
||||
group_size: int = 256,
|
||||
precision: torch.dtype = torch.float32,
|
||||
scales_precision: torch.dtype = torch.float32,
|
||||
# LoRA parameters
|
||||
lora_rank: Optional[int] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
groupsize=group_size,
|
||||
precision=precision,
|
||||
scales_precision=scales_precision,
|
||||
)
|
||||
if lora_rank is not None:
|
||||
assert lora_scale is not None, "Please specify lora scale for LoRA."
|
||||
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
|
||||
self.adaptor = nn.Sequential()
|
||||
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
|
||||
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
|
||||
self.lora_scale = lora_scale
|
||||
else:
|
||||
self.adaptor = None
|
||||
self.lora_scale = None
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized weights from the state dict."""
|
||||
if prefix + "zeros" not in state_dict:
|
||||
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
|
||||
assert prefix + "scales" in state_dict
|
||||
state_dict[prefix + "zeros"] = torch.zeros_like(
|
||||
state_dict[prefix + "scales"]
|
||||
)
|
||||
|
||||
def forward(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
module_out = super().forward(input_)
|
||||
if self.adaptor is not None:
|
||||
adaptor_out = self.adaptor(input_) * self.lora_scale
|
||||
return module_out + adaptor_out
|
||||
return module_out
|
||||
|
||||
|
||||
class Int8WeightEmbedding(torch.nn.Embedding):
|
||||
"""An embedding layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
num_embeddings: Number of embeddings.
|
||||
embedding_dim: Embedding dimension.
|
||||
padding_idx: Padding index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int,
|
||||
device=None,
|
||||
) -> None:
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized embedding weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
class Int8WeightLinear(torch.nn.Linear):
|
||||
"""A linear layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features: int, out_features: int, bias: bool = True, device=None
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, bias, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized linear weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
def _prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model: torch.nn.Module,
|
||||
group_size: int,
|
||||
lora_rank: Optional[int],
|
||||
lora_scale: Optional[float],
|
||||
):
|
||||
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
|
||||
|
||||
Note that the weights of embedding and output layers are quantized to int8.
|
||||
"""
|
||||
device = None
|
||||
for module_name, module in model.named_children():
|
||||
if module_name == "output":
|
||||
quantized_module = Int8WeightLinear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif module_name == "tok_embeddings":
|
||||
quantized_module = Int8WeightEmbedding(
|
||||
num_embeddings=module.num_embeddings,
|
||||
embedding_dim=module.embedding_dim,
|
||||
padding_idx=module.padding_idx,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
|
||||
quantized_module = Int8DynActInt4WeightLinearLoRA(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=False,
|
||||
group_size=group_size,
|
||||
lora_rank=lora_rank,
|
||||
lora_scale=lora_scale,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
else:
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
module, group_size, lora_rank, lora_scale
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert_to_int4_quantized_model(
|
||||
model: Transformer,
|
||||
model_args: ModelArgs,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
) -> Transformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
|
||||
quant_config = config.quantization
|
||||
if not isinstance(quant_config, Int4QuantizationConfig):
|
||||
raise ValueError("Only int4 quantization is supported")
|
||||
|
||||
if quant_config.type != QuantizationType.int4.value:
|
||||
raise ValueError("Only int4 quantization is supported")
|
||||
|
||||
if quant_config.scheme != "int4_weight_int8_dynamic_activation":
|
||||
raise NotImplementedError(
|
||||
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
|
||||
)
|
||||
|
||||
if model_args.quantization_args is None:
|
||||
raise ValueError("'quantization_args' cannot be None. Please specify it.")
|
||||
|
||||
group_size = model_args.quantization_args.group_size
|
||||
if group_size is None:
|
||||
raise ValueError(
|
||||
"'group_size' cannot be None in 'quantization_args'. Please specify it."
|
||||
)
|
||||
|
||||
if model_args.lora_args is None:
|
||||
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
|
||||
lora_rank = None
|
||||
lora_scale = None
|
||||
else:
|
||||
lora_rank = model_args.lora_args.rank
|
||||
lora_scale = model_args.lora_args.scale
|
||||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model, group_size, lora_rank, lora_scale
|
||||
)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
return model.to(device)
|
||||
|
|
|
@ -37,6 +37,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
META_REFERENCE_DEPS
|
||||
+ [
|
||||
"fbgemm-gpu",
|
||||
"torchao==0.5.0",
|
||||
]
|
||||
),
|
||||
module="llama_stack.providers.impls.meta_reference.inference",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue