diff --git a/.gitignore b/.gitignore
index a6c204131..897494f21 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,6 +13,7 @@ xcuserdata/
 Package.resolved
 *.pte
 *.ipynb_checkpoints*
+.idea
 .venv/
 .idea
 _build
diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 7359c6057..892da13ad 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -172,7 +172,7 @@ async def run_mm_main(
         ],
     )
     cprint(f"User>{message.content}", "green")
-    iterator = client.chat_completion(
+    iterator = await client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
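
Note on the client.py hunk: `chat_completion` is declared `async def` on the protocol, so calling it without `await` hands `async for` a coroutine instead of an async iterator. A minimal consumption sketch, assuming a running stack endpoint; the model name is illustrative, and the chunk fields follow the stream types in inference.py:

    iterator = await client.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="hello world")],
        stream=True,
    )
    async for chunk in iterator:
        # each chunk is a ChatCompletionResponseStreamChunk
        cprint(chunk.event.delta, "cyan", end="")
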
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 4ee01acae..d1ff047b0 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -25,6 +25,7 @@ class LogProbConfig(BaseModel):
 class QuantizationType(Enum):
     bf16 = "bf16"
     fp8 = "fp8"
+    int4 = "int4"
 
 
 @json_schema_type
@@ -37,8 +38,14 @@ class Bf16QuantizationConfig(BaseModel):
     type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
 
 
+@json_schema_type
+class Int4QuantizationConfig(BaseModel):
+    type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
+    scheme: Optional[str] = None
+
+
 QuantizationConfig = Annotated[
-    Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
+    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
     Field(discriminator="type"),
 ]
 
@@ -219,8 +226,6 @@ class Inference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
 
-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
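
Because `QuantizationConfig` is a union discriminated on `type`, a plain dict deserializes to the matching config class, including the new `Int4QuantizationConfig`. A quick sketch using pydantic's `parse_obj_as` (available in the pydantic versions this repo targets):

    from pydantic import parse_obj_as

    cfg = parse_obj_as(
        QuantizationConfig,
        {"type": "int4", "scheme": "int4_weight_int8_dynamic_activation"},
    )
    assert isinstance(cfg, Int4QuantizationConfig)
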
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index b424a9347..ebce1024b 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -30,7 +30,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-
 from pydantic import BaseModel
 from termcolor import cprint
 
@@ -43,7 +42,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )
 
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import (
+    Fp8QuantizationConfig,
+    Int4QuantizationConfig,
+    MetaReferenceInferenceConfig,
+    MetaReferenceQuantizedInferenceConfig,
+)
 
 
 def model_checkpoint_dir(model) -> str:
@@ -131,18 +135,34 @@ class Llama:
             ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
         if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            from .quantization.loader import convert_to_quantized_model
-
-            # load on CPU in bf16 so that fp8 conversion does not find an
-            # unexpected (fp32, e.g.) datatype
-            torch.set_default_tensor_type(torch.BFloat16Tensor)
-            if model_args.vision_chunk_size > 0:
-                model = CrossAttentionTransformer(model_args)
-                model.setup_cache(model_args.max_batch_size, torch.bfloat16)
-            else:
+            if isinstance(config.quantization, Fp8QuantizationConfig):
+                from .quantization.loader import convert_to_fp8_quantized_model
+
+                # load on CPU in bf16 so that fp8 conversion does not find an
+                # unexpected (fp32, e.g.) datatype
+                torch.set_default_tensor_type(torch.BFloat16Tensor)
+                if model_args.vision_chunk_size > 0:
+                    model = CrossAttentionTransformer(model_args)
+                    model.setup_cache(model_args.max_batch_size, torch.bfloat16)
+                else:
+                    model = Transformer(model_args)
+                model.load_state_dict(state_dict, strict=False)
+                model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
+            elif isinstance(config.quantization, Int4QuantizationConfig):
+                from .quantization.loader import convert_to_int4_quantized_model
+
+                assert (
+                    config.quantization.scheme is not None
+                ), "Please specify a quantization scheme."
+
                 model = Transformer(model_args)
-            model.load_state_dict(state_dict, strict=False)
-            model = convert_to_quantized_model(model, config, ckpt_dir)
+                model = convert_to_int4_quantized_model(model, model_args, config)
+                model.load_state_dict(state_dict, strict=True)
+            else:
+                raise NotImplementedError(
+                    "Currently int4 and fp8 are the only supported quantization methods."
+                )
         else:
             if torch.cuda.is_bf16_supported():
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
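
The two quantized paths above load in opposite orders: fp8 materializes bf16 weights first and quantizes them afterwards, while int4 rewrites the module tree first (so parameter names and dtypes match the already-quantized checkpoint) and only then loads with strict=True. Schematically, using the functions from this diff:

    # fp8: checkpoint holds bf16 weights; quantize after loading
    model = Transformer(model_args)
    model.load_state_dict(state_dict, strict=False)
    model = convert_to_fp8_quantized_model(model, config, ckpt_dir)

    # int4: checkpoint already holds quantized weights and scales;
    # convert the module tree before loading, and require every key to match
    model = Transformer(model_args)
    model = convert_to_int4_quantized_model(model, model_args, config)
    model.load_state_dict(state_dict, strict=True)
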
diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
index bd59fe618..e07c9fa3b 100644
--- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
+++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py
@@ -8,19 +8,25 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import os
-from typing import Optional
+from typing import Any, Dict, List, Optional
 
 import torch
+from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 
-from llama_models.datatypes import CheckpointQuantizationFormat
-from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
+from llama_models.datatypes import CheckpointQuantizationFormat
+
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
 from termcolor import cprint
-from torch import Tensor
+from torch import nn, Tensor
+
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from llama_stack.apis.inference import QuantizationType
+from llama_stack.apis.inference.inference import Int4QuantizationConfig
 
 from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceQuantizedInferenceConfig,
@@ -37,7 +43,7 @@ def swiglu_wrapper(
     return reduce_from_model_parallel_region(out)
 
 
-def convert_to_quantized_model(
+def convert_to_fp8_quantized_model(
     model: Transformer,
     config: MetaReferenceQuantizedInferenceConfig,
     checkpoint_dir: str,
@@ -99,3 +105,241 @@ def convert_to_quantized_model(
         if not isinstance(parameter, Fp8ScaledWeights):
             parameter.data = parameter.to(device="cuda")
     return model
+
+
+class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
+    """
+    Int8DynActInt4WeightLinear with LoRA adaptor.
+
+    Args:
+        in_features: Number of input features.
+        out_features: Number of output features.
+        bias: Whether to use bias.
+        device: Device to use.
+        group_size: Group size for quantization.
+        precision: Precision of quantization.
+        scales_precision: Precision of scales.
+        lora_rank: Rank of LoRA adaptor.
+        lora_scale: Scale of LoRA adaptor.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias=False,
+        device=None,
+        # quantization parameters
+        group_size: int = 256,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+        # LoRA parameters
+        lora_rank: Optional[int] = None,
+        lora_scale: Optional[float] = None,
+    ) -> None:
+        super().__init__(
+            in_features,
+            out_features,
+            bias=bias,
+            device=device,
+            groupsize=group_size,
+            precision=precision,
+            scales_precision=scales_precision,
+        )
+        if lora_rank is not None:
+            assert lora_scale is not None, "Please specify lora scale for LoRA."
+            # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
+            self.adaptor = nn.Sequential()
+            self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
+            self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
+            self.lora_scale = lora_scale
+        else:
+            self.adaptor = None
+            self.lora_scale = None
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(
+        self,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        local_metadata: Dict[str, Any],
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ) -> None:
+        """A hook to load the quantized weights from the state dict."""
+        if prefix + "zeros" not in state_dict:
+            # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
+            assert prefix + "scales" in state_dict
+            state_dict[prefix + "zeros"] = torch.zeros_like(
+                state_dict[prefix + "scales"]
+            )
+
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        module_out = super().forward(input_)
+        if self.adaptor is not None:
+            adaptor_out = self.adaptor(input_) * self.lora_scale
+            return module_out + adaptor_out
+        return module_out
+
+
+class Int8WeightEmbedding(torch.nn.Embedding):
+    """An embedding layer to load int8 weights.
+
+    Args:
+        num_embeddings: Number of embeddings.
+        embedding_dim: Embedding dimension.
+        padding_idx: Padding index.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int,
+        device=None,
+    ) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
+
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(
+        self,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        local_metadata: Dict[str, Any],
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ) -> None:
+        """A hook to load the quantized embedding weight and scales from the state dict."""
+        weights = state_dict.pop(prefix + "weight")
+        scales = state_dict.pop(prefix + "scales")
+        state_dict[prefix + "weight"] = weights * scales
+
+
+class Int8WeightLinear(torch.nn.Linear):
+    """A linear layer to load int8 weights.
+
+    Args:
+        in_features: Number of input features.
+        out_features: Number of output features.
+        bias: Whether to use bias.
+    """
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = True, device=None
+    ) -> None:
+        super().__init__(in_features, out_features, bias, device=device)
+
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(
+        self,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        local_metadata: Dict[str, Any],
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ) -> None:
+        """A hook to load the quantized linear weight and scales from the state dict."""
+        weights = state_dict.pop(prefix + "weight")
+        scales = state_dict.pop(prefix + "scales")
+        state_dict[prefix + "weight"] = weights * scales
+
+
+def _prepare_model_int4_weight_int8_dynamic_activation(
+    model: torch.nn.Module,
+    group_size: int,
+    lora_rank: Optional[int],
+    lora_scale: Optional[float],
+):
+    """Prepare the model for int4 weight and int8 dynamic activation quantization.
+
+    Note that the weights of embedding and output layers are quantized to int8.
+    """
+    device = None
+    for module_name, module in model.named_children():
+        if module_name == "output":
+            quantized_module = Int8WeightLinear(
+                in_features=module.in_features,
+                out_features=module.out_features,
+                bias=module.bias,
+                device=device,
+            )
+            del module
+            setattr(model, module_name, quantized_module)
+        elif module_name == "tok_embeddings":
+            quantized_module = Int8WeightEmbedding(
+                num_embeddings=module.num_embeddings,
+                embedding_dim=module.embedding_dim,
+                padding_idx=module.padding_idx,
+                device=device,
+            )
+            del module
+            setattr(model, module_name, quantized_module)
+        elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
+            quantized_module = Int8DynActInt4WeightLinearLoRA(
+                in_features=module.in_features,
+                out_features=module.out_features,
+                bias=False,
+                group_size=group_size,
+                lora_rank=lora_rank,
+                lora_scale=lora_scale,
+                device=device,
+            )
+            del module
+            setattr(model, module_name, quantized_module)
+        else:
+            _prepare_model_int4_weight_int8_dynamic_activation(
+                module, group_size, lora_rank, lora_scale
+            )
+
+    return model
+
+
+def convert_to_int4_quantized_model(
+    model: Transformer,
+    model_args: ModelArgs,
+    config: MetaReferenceQuantizedInferenceConfig,
+) -> Transformer:
+    """Convert the model to int4 quantized model."""
+
+    quant_config = config.quantization
+    if not isinstance(quant_config, Int4QuantizationConfig):
+        raise ValueError("Only int4 quantization is supported")
+
+    if quant_config.type != QuantizationType.int4.value:
+        raise ValueError("Only int4 quantization is supported")
+
+    if quant_config.scheme != "int4_weight_int8_dynamic_activation":
+        raise NotImplementedError(
+            "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
+        )
+
+    if model_args.quantization_args is None:
+        raise ValueError("'quantization_args' cannot be None. Please specify it.")
+
+    group_size = model_args.quantization_args.group_size
+    if group_size is None:
+        raise ValueError(
+            "'group_size' cannot be None in 'quantization_args'. Please specify it."
+        )
+
+    if model_args.lora_args is None:
+        # Certain quantized models (e.g., SpinQuant) may not have LoRA.
+        lora_rank = None
+        lora_scale = None
+    else:
+        lora_rank = model_args.lora_args.rank
+        lora_scale = model_args.lora_args.scale
+
+    _prepare_model_int4_weight_int8_dynamic_activation(
+        model, group_size, lora_rank, lora_scale
+    )
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    return model.to(device)
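
A numeric illustration of what the load hooks in loader.py do; the shapes are made up, not taken from a real checkpoint:

    import torch

    # Int8WeightEmbedding / Int8WeightLinear: the checkpoint stores an int8
    # weight plus scales, and the hook rehydrates weight = weights * scales.
    weights = torch.randint(-128, 128, (8, 4), dtype=torch.int8)
    scales = torch.rand(8, 1)  # assumed per-row scales
    dense = weights * scales   # written back under the "weight" key

    # Int8DynActInt4WeightLinearLoRA: if "zeros" is missing, the hook assumes
    # symmetric quantization and fills in zeros shaped like the scales.
    state_dict = {"layer.scales": torch.rand(4, 16)}
    state_dict["layer.zeros"] = torch.zeros_like(state_dict["layer.scales"])
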
+ """ + device = None + for module_name, module in model.named_children(): + if module_name == "output": + quantized_module = Int8WeightLinear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias, + device=device, + ) + del module + setattr(model, module_name, quantized_module) + elif module_name == "tok_embeddings": + quantized_module = Int8WeightEmbedding( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + padding_idx=module.padding_idx, + device=device, + ) + del module + setattr(model, module_name, quantized_module) + elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)): + quantized_module = Int8DynActInt4WeightLinearLoRA( + in_features=module.in_features, + out_features=module.out_features, + bias=False, + group_size=group_size, + lora_rank=lora_rank, + lora_scale=lora_scale, + device=device, + ) + del module + setattr(model, module_name, quantized_module) + else: + _prepare_model_int4_weight_int8_dynamic_activation( + module, group_size, lora_rank, lora_scale + ) + + return model + + +def convert_to_int4_quantized_model( + model: Transformer, + model_args: ModelArgs, + config: MetaReferenceQuantizedInferenceConfig, +) -> Transformer: + """Convert the model to int4 quantized model.""" + + quant_config = config.quantization + if not isinstance(quant_config, Int4QuantizationConfig): + raise ValueError("Only int4 quantization is supported") + + if quant_config.type != QuantizationType.int4.value: + raise ValueError("Only int4 quantization is supported") + + if quant_config.scheme != "int4_weight_int8_dynamic_activation": + raise NotImplementedError( + "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported." + ) + + if model_args.quantization_args is None: + raise ValueError("'quantization_args' cannot be None. Please specify it.") + + group_size = model_args.quantization_args.group_size + if group_size is None: + raise ValueError( + "'group_size' cannot be None in 'quantization_args'. Please specify it." + ) + + if model_args.lora_args is None: + # Certain quantized models (e.g., SpinQuant) may not have LoRA. + lora_rank = None + lora_scale = None + else: + lora_rank = model_args.lora_args.rank + lora_scale = model_args.lora_args.scale + + _prepare_model_int4_weight_int8_dynamic_activation( + model, group_size, lora_rank, lora_scale + ) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + return model.to(device) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 28555755b..88265f1b4 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -37,6 +37,7 @@ def available_providers() -> List[ProviderSpec]: META_REFERENCE_DEPS + [ "fbgemm-gpu", + "torchao==0.5.0", ] ), module="llama_stack.providers.impls.meta_reference.inference",