mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 17:11:12 +00:00 
			
		
		
		
	# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred in recent Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only changes worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
		
			
				
	
	
		
			316 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			316 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| # type: ignore
 | |
| import os
 | |
| from typing import Any, cast
 | |
| 
 | |
| import torch
 | |
| from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
 | |
| from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 | |
| from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 | |
| from torch import Tensor, nn
 | |
| from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 | |
| 
 | |
| from ...datatypes import QuantizationMode
 | |
| from ...quantize_impls import (
 | |
|     Fp8ScaledWeights,
 | |
|     ffn_swiglu,
 | |
|     load_fp8,
 | |
|     quantize_fp8,
 | |
| )
 | |
| from ..model import Transformer, TransformerBlock
 | |
| from ..multimodal.model import CrossAttentionTransformer
 | |
| 
 | |
| 
 | |
def swiglu_wrapper(
    self,
    x: Tensor,
):
    """Replacement ``forward`` that routes a feed-forward module through ``ffn_swiglu``.

    Written as a free function taking ``self`` so that it can be bound onto an
    existing feed-forward module instance (one exposing ``w1``, ``w3`` and
    ``w2`` sub-modules with a ``weight`` attribute).

    Args:
        x: Input tensor handed to ``ffn_swiglu``.

    Returns:
        The ``ffn_swiglu`` result after it has been passed through
        ``reduce_from_model_parallel_region``.
    """
    # Argument order expected by ffn_swiglu is w1, w3, w2 (not w1, w2, w3).
    w1_weight = self.w1.weight
    w3_weight = self.w3.weight
    w2_weight = self.w2.weight
    local_out = ffn_swiglu(x, w1_weight, w3_weight, w2_weight)
    return reduce_from_model_parallel_region(local_out)
 | |
| 
 | |
| 
 | |
def convert_to_quantized_model(
    model: Transformer | CrossAttentionTransformer,
    checkpoint_dir: str,
    quantization_mode: str | None = None,
    fp8_activation_scale_ub: float | None = 1200.0,
    device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
    """Dispatch to the converter that matches ``quantization_mode``.

    Args:
        model: Model to convert.
        checkpoint_dir: Checkpoint directory, forwarded to the selected converter.
        quantization_mode: Requested mode; compared against
            ``QuantizationMode.fp8_mixed`` and ``QuantizationMode.int4_mixed``.
        fp8_activation_scale_ub: Forwarded to the fp8 converter only.
        device: Target device, forwarded to the selected converter.

    Returns:
        Whatever the selected converter returns.

    Raises:
        ValueError: If ``quantization_mode`` matches neither supported mode
            (including the default ``None``).
    """
    if quantization_mode == QuantizationMode.fp8_mixed:
        return convert_to_fp8_quantized_model(
            model=model,
            checkpoint_dir=checkpoint_dir,
            fp8_activation_scale_ub=fp8_activation_scale_ub,
            device=device,
        )

    if quantization_mode == QuantizationMode.int4_mixed:
        return convert_to_int4_quantized_model(
            model=model,
            checkpoint_dir=checkpoint_dir,
            device=device,
        )

    raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
 | |
| 
 | |
| 
 | |
def convert_to_fp8_quantized_model(
    model: Transformer,
    checkpoint_dir: str,
    fp8_activation_scale_ub: float | None = 1200.0,
    device: torch.device | None = None,
) -> Transformer:
    """Replace feed-forward weights of the inner transformer blocks with fp8 weights.

    If ``fp8_scales_<rank>.pt`` exists in ``checkpoint_dir`` (``<rank>`` being the
    value of ``get_model_parallel_rank()``), the scales stored there are used via
    ``load_fp8``; otherwise the weights are quantized on the fly via
    ``quantize_fp8``. The first block (``layer_id == 0``) and the last block
    (``layer_id == model.n_layers - 1``) are left untouched. Every converted
    block also has its feed-forward ``forward`` replaced by ``swiglu_wrapper``.

    Args:
        model: Model whose ``TransformerBlock`` sub-modules are converted in place.
        checkpoint_dir: Directory searched for the per-rank fp8 scales file.
        fp8_activation_scale_ub: Forwarded to ``load_fp8`` / ``quantize_fp8``.
        device: Passed as ``output_device`` to ``quantize_fp8``, and used as the
            target for every parameter that is not an ``Fp8ScaledWeights``.

    Returns:
        The same ``model`` object, modified in place.
    """
    # The rank does not change during conversion, so look it up once instead of
    # once per converted weight.
    rank = get_model_parallel_rank()

    # Move weights to GPU with quantization
    fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
    use_saved_scales = os.path.isfile(fp8_scales_path)
    if use_saved_scales:
        print("Loading fp8 scales...")
        fp8_scales = torch.load(fp8_scales_path, weights_only=True)
    else:
        print("Quantizing fp8 weights from bf16...")
        fp8_scales = None

    for _, block in model.named_modules():
        if not isinstance(block, TransformerBlock):
            continue
        # First and last layers are skipped. ``model.n_layers`` is read lazily,
        # and only for matching blocks, exactly as before.
        if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
            continue

        feed_forward = block.feed_forward
        # Bind swiglu_wrapper as this instance's forward.
        feed_forward.forward = swiglu_wrapper.__get__(feed_forward)  # type: ignore
        for key in ("w1", "w3", "w2"):
            param = getattr(feed_forward, key)
            if use_saved_scales:
                # NOTE(review): unlike quantize_fp8 below, no output_device is
                # passed here — confirm load_fp8's default placement is intended.
                param.weight = load_fp8(
                    param.weight,
                    fp8_scales[f"{block.layer_id}_feed_forward.{key}_{rank}"],
                    fp8_activation_scale_ub,
                )
            else:
                param.weight = quantize_fp8(
                    param.weight,
                    fp8_activation_scale_ub,
                    output_device=device,
                )

    # Everything that was not turned into fp8 weights still has to be moved.
    for _, parameter in model.named_parameters():
        if not isinstance(parameter, Fp8ScaledWeights):
            parameter.data = parameter.to(device=device)
    return model
 | |
| 
 | |
| 
 | |
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
    """
    Int8DynActInt4WeightLinear with an optional LoRA adaptor.

    When ``lora_rank`` is given, a bias-free two-layer adaptor is created
    (``A``: in_features -> lora_rank, ``B``: lora_rank -> out_features) and its
    output, multiplied by ``lora_scale``, is added to the output of the
    quantized linear layer.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
        device: Device to use.
        group_size: Group size for quantization.
        precision: Precision of quantization.
        scales_precision: Precision of scales.
        lora_rank: Rank of LoRA adaptor. ``None`` disables the adaptor.
        lora_scale: Scale of LoRA adaptor. Required when ``lora_rank`` is set.

    Raises:
        ValueError: If ``lora_rank`` is given without ``lora_scale``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=False,
        device=None,
        # quantization parameters
        group_size: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
        # LoRA parameters
        lora_rank: int | None = None,
        lora_scale: float | None = None,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            groupsize=group_size,
            precision=precision,
            scales_precision=scales_precision,
        )
        self.lora_scale: float | None = None
        self.adaptor: nn.Sequential | None = None
        if lora_rank is not None:
            # Explicit raise instead of ``assert``: asserts vanish under ``-O``,
            # which would defer the failure to ``forward`` (tensor * None).
            if lora_scale is None:
                raise ValueError("Please specify lora scale for LoRA.")
            # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
            self.adaptor = nn.Sequential()
            self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
            self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
            self.lora_scale = lora_scale
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        """A hook to load the quantized weights from the state dict.

        Inserts an all-zero ``<prefix>zeros`` entry, shaped like
        ``<prefix>scales``, when the state dict does not provide one.

        Raises:
            KeyError: If neither ``<prefix>zeros`` nor ``<prefix>scales`` is present.
        """
        zeros_key = prefix + "zeros"
        if zeros_key in state_dict:
            return

        # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
        scales_key = prefix + "scales"
        if scales_key not in state_dict:
            raise KeyError(f"'{scales_key}' is required in the state dict when '{zeros_key}' is absent.")
        state_dict[zeros_key] = torch.zeros_like(state_dict[scales_key])

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Return the quantized linear output, plus the scaled adaptor output if LoRA is enabled."""
        module_out = super().forward(input_)
        if self.adaptor is None:
            return module_out
        adaptor_out = self.adaptor(input_) * self.lora_scale
        return module_out + adaptor_out
 | |
| 
 | |
| 
 | |
class Int8WeightEmbedding(torch.nn.Embedding):
    """Embedding layer that loads a weight stored together with a ``scales`` tensor.

    A load-state-dict pre-hook replaces the stored ``weight`` entry with
    ``weight * scales`` and removes the ``scales`` entry before the regular
    loading logic runs.

    Args:
        num_embeddings: Number of embeddings.
        embedding_dim: Embedding dimension.
        padding_idx: Padding index, forwarded to ``torch.nn.Embedding``.
        device: Device, forwarded to ``torch.nn.Embedding``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        device=None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            device=device,
        )
        # Rewrite incoming state dicts before the default loader sees them.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        """Fold ``<prefix>scales`` into ``<prefix>weight`` inside ``state_dict`` (in place)."""
        weight_key = prefix + "weight"
        stored_weight = state_dict.pop(weight_key)
        stored_scales = state_dict.pop(prefix + "scales")
        state_dict[weight_key] = stored_weight * stored_scales
 | |
| 
 | |
| 
 | |
class Int8WeightLinear(torch.nn.Linear):
    """Linear layer that loads a weight stored together with a ``scales`` tensor.

    A load-state-dict pre-hook replaces the stored ``weight`` entry with
    ``weight * scales`` and removes the ``scales`` entry before the regular
    loading logic runs.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
        device: Device, forwarded to ``torch.nn.Linear``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device)
        # Rewrite incoming state dicts before the default loader sees them.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        """Fold ``<prefix>scales`` into ``<prefix>weight`` inside ``state_dict`` (in place)."""
        weight_key = prefix + "weight"
        stored_weight = state_dict.pop(weight_key)
        stored_scales = state_dict.pop(prefix + "scales")
        state_dict[weight_key] = stored_weight * stored_scales
 | |
| 
 | |
| 
 | |
def _prepare_model_int4_weight_int8_dynamic_activation(
    model: torch.nn.Module,
    group_size: int,
    lora_rank: int | None,
    lora_scale: float | None,
):
    """Prepare the model for int4 weight and int8 dynamic activation quantization.

    Walks the direct children of ``model`` and swaps them in place:

    * the child named ``"output"`` becomes an ``Int8WeightLinear``;
    * the child named ``"tok_embeddings"`` becomes an ``Int8WeightEmbedding``;
    * any other ``ColumnParallelLinear`` / ``RowParallelLinear`` / ``nn.Linear``
      becomes a bias-free ``Int8DynActInt4WeightLinearLoRA``;
    * every remaining child is processed recursively.

    Note that the weights of embedding and output layers are quantized to int8.

    Args:
        model: Module whose children are replaced in place.
        group_size: Quantization group size for the int4 linear layers.
        lora_rank: LoRA rank forwarded to the int4 linear layers (``None`` for no LoRA).
        lora_scale: LoRA scale forwarded to the int4 linear layers.

    Returns:
        The same ``model`` object.
    """
    device = None
    # Iterate the generator directly (no list copy) so that ``del module`` plus
    # ``setattr`` drops the last references to the replaced module right away.
    for module_name, module in model.named_children():
        if module_name == "output":
            quantized_module = Int8WeightLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                # ``module.bias`` is a tensor or None, while the constructor
                # expects a bool; truth-testing a multi-element tensor raises.
                bias=module.bias is not None,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        elif module_name == "tok_embeddings":
            quantized_module = Int8WeightEmbedding(
                num_embeddings=module.num_embeddings,
                embedding_dim=module.embedding_dim,
                padding_idx=module.padding_idx,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
            quantized_module = Int8DynActInt4WeightLinearLoRA(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=False,
                group_size=group_size,
                lora_rank=lora_rank,
                lora_scale=lora_scale,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        else:
            _prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)

    return model
 | |
| 
 | |
| 
 | |
def convert_to_int4_quantized_model(
    model: Transformer | CrossAttentionTransformer,
    checkpoint_dir: str,
    device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
    """Convert the model to int4 quantized model.

    Reads the quantization and LoRA settings from ``model.params``, replaces the
    model's layers via ``_prepare_model_int4_weight_int8_dynamic_activation``
    and moves the result to ``device``.

    Args:
        model: Model to convert; its layers are replaced in place.
        checkpoint_dir: Not used by this function.
        device: Target device for the converted model.

    Returns:
        The result of ``model.to(device=device)``.

    Raises:
        ValueError: If ``quantization_args``, its ``scheme`` or its
            ``group_size`` is missing.
        NotImplementedError: If the scheme is anything other than
            ``int4_weight_int8_dynamic_activation``.
    """
    model_args = model.params
    quantization_args = model_args.quantization_args
    # Raise explicitly (as the checks below do) rather than ``assert``, which is
    # stripped under ``-O``.
    if quantization_args is None:
        raise ValueError("Quantization args must be specified.")
    if quantization_args.scheme is None:
        raise ValueError("Quantization scheme must be specified in 'quantization_args'.")

    if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
        raise NotImplementedError(
            "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
        )

    group_size = quantization_args.group_size
    if group_size is None:
        raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")

    lora_args = model_args.lora_args
    if lora_args is None:
        # Certain quantized models (e.g., SpinQuant) may not have LoRA.
        lora_rank = None
        lora_scale = None
    else:
        lora_rank = lora_args.rank
        lora_scale = lora_args.scale

    _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
    return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
 |