llama-stack/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
Ashwin Bharambe 314ee09ae3
chore: move all Llama Stack types from llama-models to llama-stack (#1098)
llama-models should have extremely minimal cruft. Its sole purpose
should be didactic -- show the simplest implementation of the llama
models and document the prompt formats, etc.

This PR is the complement to
https://github.com/meta-llama/llama-models/pull/279

## Test Plan

Ensure all `llama` CLI `model` sub-commands work:

```bash
llama model list
llama model download --model-id ...
llama model prompt-format -m ...
```

Ran tests:
```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/
LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/
LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/
```

Create a fresh venv `uv venv && source .venv/bin/activate` and run
`llama stack build --template fireworks --image-type venv` followed by
`llama stack run together --image-type venv` <-- the server runs

Also checked that the OpenAPI generator can run and there is no change
in the generated files as a result.

```bash
cd docs/openapi_generator
sh run_openapi_generator.sh
```
2025-02-14 09:10:59 -08:00

319 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import logging
import os
from typing import Any, Dict, List, Optional
import torch
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
from llama_stack.models.llama.sku_list import resolve_model
from ..config import MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)
def swiglu_wrapper(
self,
x: Tensor,
):
from .fp8_impls import ffn_swiglu
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
fp8_activation_scale_ub,
)
else:
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
"""
Int8DynActInt4WeightLinear with LoRA adaptor.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
device: Device to use.
group_size: Group size for quantization.
precision: Precision of quantization.
scales_precision: Precision of scales.
lora_rank: Rank of LoRA adaptor.
lora_scale: Scale of LoRA adaptor.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias=False,
device=None,
# quantization parameters
group_size: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
lora_rank: Optional[int] = None,
lora_scale: Optional[float] = None,
) -> None:
super().__init__(
in_features,
out_features,
bias=bias,
device=device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
self.adaptor = nn.Sequential()
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
self.lora_scale = lora_scale
else:
self.adaptor = None
self.lora_scale = None
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
if self.adaptor is not None:
adaptor_out = self.adaptor(input_) * self.lora_scale
return module_out + adaptor_out
return module_out
class Int8WeightEmbedding(torch.nn.Embedding):
"""An embedding layer to load int8 weights.
Args:
num_embeddings: Number of embeddings.
embedding_dim: Embedding dimension.
padding_idx: Padding index.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
device=None,
) -> None:
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
class Int8WeightLinear(torch.nn.Linear):
"""A linear layer to load int8 weights.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
lora_rank: Optional[int],
lora_scale: Optional[float],
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
Note that the weights of embedding and output layers are quantized to int8.
"""
device = None
for module_name, module in model.named_children():
if module_name == "output":
quantized_module = Int8WeightLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif module_name == "tok_embeddings":
quantized_module = Int8WeightEmbedding(
num_embeddings=module.num_embeddings,
embedding_dim=module.embedding_dim,
padding_idx=module.padding_idx,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
bias=False,
group_size=group_size,
lora_rank=lora_rank,
lora_scale=lora_scale,
device=device,
)
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
return model
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
config: MetaReferenceQuantizedInferenceConfig,
) -> Transformer:
"""Convert the model to int4 quantized model."""
if model_args.quantization_args is None:
raise ValueError("'quantization_args' cannot be None. Please specify it.")
quantization_args = model_args.quantization_args
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
)
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
lora_rank = None
lora_scale = None
else:
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)