From b75e671c3be458389ba0ce1547c225f35baff729 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Wed, 22 Jan 2025 14:52:05 -0800
Subject: [PATCH] temp commit

---
 .../torchtune/common/checkpointer.py          | 309 ++++++++++++++----
 .../inline/post_training/torchtune/config.py  |   1 +
 .../recipes/lora_finetuning_single_device.py  |  38 ++-
 3 files changed, 262 insertions(+), 86 deletions(-)

diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index 359fc43ca..882deec4f 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -12,7 +12,24 @@ from typing import Any, Dict, List
 import torch
+from safetensors.torch import save_file
 from torchtune import training
 from torchtune.models import convert_weights
-from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.training.checkpointing._utils import (
+    ADAPTER_CONFIG_FNAME,
+    ADAPTER_MODEL_FNAME,
+    check_outdir_not_in_ckptdir,
+    copy_files,
+    get_adapter_checkpoint_path,
+    get_model_checkpoint_path,
+    get_recipe_checkpoint_path,
+    ModelType,
+    RECIPE_STATE_DIRNAME,
+    REPO_ID_FNAME,
+    safe_torch_load,
+    SAFETENSOR_INDEX_FNAME,
+    SHARD_FNAME,
+    SUFFIXES_TO_NOT_COPY,
+    TORCH_INDEX_FNAME,
+)
 from torchtune.utils._logging import get_logger
 
 logger = get_logger("DEBUG")
@@ -81,83 +97,239 @@ class TorchtuneCheckpointer:
         state_dict: Dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
+        checkpoint_format: str = "meta",
     ) -> str:
         model_file_path = (
             Path(self._output_dir)
             / f"{self._model_id}-{self._training_algorithm}-{epoch}"
         )
+        if checkpoint_format == "meta":
+            model_file_path.mkdir(parents=True, exist_ok=True)
-        model_file_path.mkdir(parents=True, exist_ok=True)
-
-        # copy the related files for inference
-        source_path = Path.joinpath(self._checkpoint_dir, "params.json")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "params.json"),
-            )
-        source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "tokenizer.model"),
-            )
-        source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "orig_params.json"),
-            )
-
-        if not adapter_only:
-            model_state_dict = state_dict[training.MODEL_KEY]
-            if self._model_type == ModelType.LLAMA3_VISION:
-                from torchtune.models.llama3_2_vision._convert_weights import (
-                    llama3_vision_tune_to_meta,
-                )
+
+            # copy the related files for inference
+            source_path = Path.joinpath(self._checkpoint_dir, "params.json")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "params.json"),
+                )
+            source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "tokenizer.model"),
+                )
+            source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "orig_params.json"),
+                )
+
+            if not adapter_only:
+                model_state_dict = state_dict[training.MODEL_KEY]
+                if self._model_type == ModelType.LLAMA3_VISION:
+                    from torchtune.models.llama3_2_vision._convert_weights import (
+                        llama3_vision_tune_to_meta,
+                    )
+
+                    state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
+                        model_state_dict
+                    )
+                else:
+                    # llama3_2 has tied weights, so we need to add the output.weight key
+                    if (
+                        self._model_type == ModelType.LLAMA3_2
+                        and "output.weight" not in model_state_dict
+                    ):
+                        model_state_dict["output.weight"] = model_state_dict[
+                            "tok_embeddings.weight"
+                        ]
+
+                    state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
+                        model_state_dict
+                    )
+
+                model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
+
+                torch.save(state_dict[training.MODEL_KEY], model_file_name)
+                logger.info(
+                    "Model checkpoint of size "
+                    f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
+                    f"saved to {model_file_name}"
+                )
+
+            if training.ADAPTER_KEY in state_dict:
+                adapter_file_path = model_file_path / "adapter"
+                adapter_file_path.mkdir(parents=True, exist_ok=True)
+                adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
+                torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
+                    f"saved to {adapter_file_name}"
+                )
+
+            elif adapter_only:
+                raise ValueError(
+                    "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+                )
+
+            print("model_file_path", str(model_file_path))
+        elif checkpoint_format == "hf":
+            # the config.json file contains model params needed for state dict conversion
+            config = json.loads(
+                Path.joinpath(self._checkpoint_dir, "config.json").read_text()
+            )
+            if not adapter_only:
+                state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf(
+                    state_dict[training.MODEL_KEY],
+                    num_heads=config["num_attention_heads"],
+                    num_kv_heads=config["num_key_value_heads"],
+                    dim=config["hidden_size"],
+                    head_dim=config.get("head_dim", None),
+                )
+
+                # split the state_dict into separate dicts, one for each output checkpoint file
+                # e.g. split_state_dicts = {
+                #     "0001": {"key1": tensor1, "key2": tensor2},
+                #     "0002": {"key3": tensor3}
+                # }
+                split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
+                total_size = 0
+                for key, weight in state_dict[training.MODEL_KEY].items():
+                    cpt_idx = self._weight_map[key]
+
+                    # initialize dict
+                    if cpt_idx not in split_state_dicts:
+                        split_state_dicts[cpt_idx] = {}
+
+                    split_state_dicts[cpt_idx].update({key: weight})
+                    total_size += weight.numel() * weight.element_size()
+
+                # write the partitioned state dicts to the right checkpoint file
+                # e.g. model-00001-of-00004.safetensors, model-00002-of-00004.safetensors, etc.
+                num_shards = len(split_state_dicts)
+                map_original_name_to_new_name = {}
+                for cpt_idx, model_state_dict in split_state_dicts.items():
+                    # TODO: We should probably use the original shard name and just add a prefix
+                    # however, having the SHARD_FNAME standardizes our checkpoints
+                    shard_name = SHARD_FNAME.format(
+                        cpt_idx=f"{cpt_idx}".zfill(5),
+                        num_shards=f"{num_shards}".zfill(5),
+                    )
+                    map_original_name_to_new_name[cpt_idx] = shard_name
+                    output_path = Path.joinpath(model_file_path, shard_name)
+                    output_path.parent.mkdir(parents=True, exist_ok=True)
+                    output_path = output_path.with_suffix(".safetensors")
+                    save_file(model_state_dict, output_path, metadata={"format": "pt"})
+
+                    logger.info(
+                        "Model checkpoint of size "
+                        f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                        f"saved to {output_path}"
+                    )
+
+                # Save the appropriate index file based on serialization format,
+                # e.g. {"metadata": {"total_size": 1234}, "weight_map": {"key1": "model_0001.safetensors", "key2": "model_0002.safetensors"}}
+                weight_map = {
+                    k: map_original_name_to_new_name[cpt_idx] + ".safetensors"
+                    for k, cpt_idx in self._weight_map.items()
+                }
+                index_file_name = SAFETENSOR_INDEX_FNAME
+
+                index_path = Path.joinpath(model_file_path, index_file_name)
+
+                index_data = {
+                    "metadata": {"total_size": total_size},
+                    "weight_map": weight_map,
+                }
+                with open(index_path, "w") as f:
+                    json.dump(index_data, f, indent=2)
+
+            if training.ADAPTER_KEY in state_dict:
+
+                # TODO: saving it "as is" is a requirement because, if we only save with
+                # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
+                # convert_weights.peft_to_tune. The .pt format is not needed, but
+                # it is an easy way to distinguish the adapters. Ideally we should save only one.
+                output_path = Path.joinpath(
+                    model_file_path, ADAPTER_MODEL_FNAME
+                ).with_suffix(".pt")
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                torch.save(state_dict[training.ADAPTER_KEY], output_path)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
+                )
+
+                state_dict[training.ADAPTER_KEY] = (
+                    convert_weights.tune_to_peft_adapter_weights(
+                        state_dict[training.ADAPTER_KEY],
+                        num_heads=config["num_attention_heads"],
+                        num_kv_heads=config["num_key_value_heads"],
+                        dim=config["hidden_size"],
+                        head_dim=config.get("head_dim", None),
+                    )
+                )
+                output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME)
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                output_path = output_path.with_suffix(".safetensors")
+                save_file(
+                    state_dict[training.ADAPTER_KEY],
+                    output_path,
+                    metadata={"format": "pt"},
+                )
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
+                )
+            elif adapter_only:
+                raise ValueError(
+                    "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+                )
+
+            if training.ADAPTER_CONFIG in state_dict:
+                state_dict[training.ADAPTER_CONFIG] = (
+                    convert_weights.tune_to_peft_adapter_config(
+                        adapter_config=state_dict[training.ADAPTER_CONFIG],
+                        base_model_name_or_path=self.repo_id,
+                    )
+                )
+
+                output_path = Path.joinpath(
+                    model_file_path, ADAPTER_CONFIG_FNAME
+                ).with_suffix(".json")
+                with open(output_path, "w") as f:
+                    json.dump(state_dict[training.ADAPTER_CONFIG], f)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
+                )
+
+            # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch},
+            # so it's easy to run inference with the model using this epoch's checkpoint
+            copy_files(
+                self._checkpoint_dir,
+                model_file_path,
+                ignore_suffixes=SUFFIXES_TO_NOT_COPY,
+            )
+            logger.info("Saving final epoch checkpoint.")
+            if adapter_only:
+                logger.info(
+                    "Please note that you have set adapter_only=True, so only adapter weights will be saved. "
+                    "You need to merge the adapter weights into your base model for further use. "
+                    f"See {self.__class__.__name__}.save_checkpoint for more details."
+                )
+            else:
+                logger.info(
+                    "The full model checkpoint, including all weights and configurations, has been saved successfully. "
+                    "You can now use this checkpoint for further training or inference."
+                )
-                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
-                    model_state_dict
-                )
-            else:
-                # llama3_2 has tied weights, so we need to add the output.weight key
-                if (
-                    self._model_type == ModelType.LLAMA3_2
-                    and "output.weight" not in model_state_dict
-                ):
-                    model_state_dict["output.weight"] = model_state_dict[
-                        "tok_embeddings.weight"
-                    ]
-
-                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
-                    model_state_dict
-                )
-
-            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
-
-            torch.save(state_dict[training.MODEL_KEY], model_file_name)
-            logger.info(
-                "Model checkpoint of size "
-                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
-                f"saved to {model_file_name}"
-            )
-
-        if training.ADAPTER_KEY in state_dict:
-            adapter_file_path = model_file_path / "adapter"
-            adapter_file_path.mkdir(parents=True, exist_ok=True)
-            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
-            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
-            logger.info(
-                "Adapter checkpoint of size "
-                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
-                f"saved to {adapter_file_name}"
-            )
-
-        elif adapter_only:
-            raise ValueError(
-                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
-            )
-
-        print("model_file_path", str(model_file_path))
+        else:
+            raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
 
         return str(model_file_path)

diff --git a/llama_stack/providers/inline/post_training/torchtune/config.py b/llama_stack/providers/inline/post_training/torchtune/config.py
index 3ffa55c70..ecb80dd29 100644
--- a/llama_stack/providers/inline/post_training/torchtune/config.py
+++ b/llama_stack/providers/inline/post_training/torchtune/config.py
@@ -11,3 +11,4 @@ from pydantic import BaseModel
 
 class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
+    checkpoint_format: Optional[str] = "hf"

diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 80e206ebb..e64676a5e 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -15,24 +15,6 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from llama_models.sku_list import resolve_model
 
-from torch import nn
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import padded_collate_sft
-
-from torchtune.modules.loss import CEWithChunkedOutputLoss
-from torchtune.modules.peft import (
-    get_adapter_params,
-    get_adapter_state_dict,
-    get_lora_module_names,
-    get_merged_lora_ckpt,
-    set_trainable_params,
-    validate_missing_and_unexpected_for_lora,
-)
-from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
 
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
@@ -60,6 +42,24 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from torch import nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import modules, training, utils as torchtune_utils
+from torchtune.data import padded_collate_sft
+
+from torchtune.modules.loss import CEWithChunkedOutputLoss
+from torchtune.modules.peft import (
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 log = logging.getLogger(__name__)
 
@@ -129,6 +129,7 @@
         self.checkpoint_dir = model_checkpoint_dir(model)
 
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
+        self._checkpoint_format = config.checkpoint_format
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -444,6 +445,7 @@
         return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
+            checkpoint_format=self._checkpoint_format,
         )
 
     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
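
How the new knob is wired end to end: `TorchtunePostTrainingConfig.checkpoint_format` defaults to `"hf"`, the recipe captures it as `self._checkpoint_format`, and `save_checkpoint` branches on it (`"meta"` keeps the old single-file `consolidated.00.pth` layout, `"hf"` writes sharded safetensors). A minimal sketch of the config side, using only names from this patch; the `torch_seed` value is arbitrary:

```python
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)

# "hf" is the default; pass "meta" to keep the original
# consolidated.00.pth output layout.
config = TorchtunePostTrainingConfig(torch_seed=42, checkpoint_format="meta")

# The recipe stores this value and forwards it on every save:
#   self._checkpoint_format = config.checkpoint_format
#   self._checkpointer.save_checkpoint(
#       ckpt_dict,
#       epoch=epoch,
#       checkpoint_format=self._checkpoint_format,
#   )
```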
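Because the `"hf"` branch writes `model-XXXXX-of-YYYYY.safetensors` shards plus an index file (torchtune's `SAFETENSOR_INDEX_FNAME`, i.e. `model.safetensors.index.json`) and copies the remaining source files via `copy_files`, the output directory should resemble a regular Hugging Face checkpoint. A hedged verification sketch: the checkpoint path below is hypothetical (the real one is `{output_dir}/{model_id}-{training_algorithm}-{epoch}`), and loading with `transformers` assumes the copied files include a usable `config.json` and tokenizer:

```python
import json
from pathlib import Path

from transformers import AutoModelForCausalLM

# Hypothetical output directory produced by save_checkpoint(..., checkpoint_format="hf")
ckpt_dir = Path.home() / ".llama" / "checkpoints" / "my-model-LoRA-0"

# Every tensor key in the index's weight_map should resolve to a shard on disk.
index = json.loads((ckpt_dir / "model.safetensors.index.json").read_text())
shards = set(index["weight_map"].values())
assert all((ckpt_dir / shard).exists() for shard in shards)

# With config.json and tokenizer files copied alongside the shards, the
# directory can be loaded like any Hugging Face checkpoint.
model = AutoModelForCausalLM.from_pretrained(ckpt_dir)
```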