diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index 882deec4f..ae78c227f 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -4,12 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import os
 import shutil
 from pathlib import Path
 from typing import Any, Dict, List
 
 import torch
+from safetensors.torch import save_file
 from torchtune import training
 from torchtune.models import convert_weights
 from torchtune.training.checkpointing._utils import (
@@ -103,7 +105,7 @@ class TorchtuneCheckpointer:
             Path(self._output_dir)
             / f"{self._model_id}-{self._training_algorithm}-{epoch}"
         )
-        if format == "meta":
+        if checkpoint_format == "meta":
             model_file_path.mkdir(parents=True, exist_ok=True)
 
             # copy the related files for inference
@@ -176,79 +178,27 @@ class TorchtuneCheckpointer:
             )
             print("model_file_path", str(model_file_path))
 
-        elif format == "hf":
+        elif checkpoint_format == "hf":
+            # Note: for saving Hugging Face format checkpoints, we only support saving adapter weights now
+            # the config.json file contains model params needed for state dict conversion
             config = json.loads(
-                Path.joinpath(self._checkpoint_dir, "config.json").read_text()
+                Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text()
             )
 
-            if not adapter_only:
-                state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf(
-                    state_dict[training.MODEL_KEY],
-                    num_heads=config["num_attention_heads"],
-                    num_kv_heads=config["num_key_value_heads"],
-                    dim=config["hidden_size"],
-                    head_dim=config.get("head_dim", None),
-                )
-                # split the state_dict into separate dicts, one for each output checkpoint file
-                # e.g. split_state_dicts= {
-                #     "0001": {"key1": tensor1, "key2": tensor2},
-                #     "0002": {"key3": tensor3}
-                # }
-                split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
-                total_size = 0
-                for key, weight in state_dict[training.MODEL_KEY].items():
-                    cpt_idx = self._weight_map[key]
-
-                    # initialize dict
-                    if cpt_idx not in split_state_dicts:
-                        split_state_dicts[cpt_idx] = {}
-
-                    split_state_dicts[cpt_idx].update({key: weight})
-                    total_size += weight.numel() * weight.element_size()
-
-                # write the partitioned state dicts to the right checkpoint file
-                # e.g. model-00001-of-00004.safetensors, model-00002-of-00004.safetensors, etc
-                num_shards = len(split_state_dicts)
-                map_original_name_to_new_name = {}
-                for cpt_idx, model_state_dict in split_state_dicts.items():
-                    # TODO: We should probably use the original shard name and just add a prefix
-                    # however, having the SHARD_FNAME standardizes our checkpoints
-                    shard_name = SHARD_FNAME.format(
-                        cpt_idx=f"{cpt_idx}".zfill(5),
-                        num_shards=f"{num_shards}".zfill(5),
-                    )
-                    map_original_name_to_new_name[cpt_idx] = shard_name
-                    output_path = Path.joinpath(model_file_path, shard_name)
-                    output_path.parent.mkdir(parents=True, exist_ok=True)
-                    output_path = output_path.with_suffix(".safetensors")
-                    save_file(model_state_dict, output_path, metadata={"format": "pt"})
-
-                    logger.info(
-                        "Model checkpoint of size "
-                        f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
-                        f"saved to {output_path}"
-                    )
-
-                # Save the appropriate index file based on serialization format
-                # e.g. {metadata: {total_size: 1234}, weight_map: {"key1": "model_0001.safetensors", "key2": "model_0002.safetensors"}}
-                weight_map = {
-                    k: map_original_name_to_new_name[cpt_idx] + ".safetensors"
-                    for k, cpt_idx in self._weight_map.items()
-                }
-                index_file_name = SAFETENSOR_INDEX_FNAME
-
-                index_path = Path.joinpath(model_file_path, index_file_name)
-
-                index_data = {
-                    "metadata": {"total_size": total_size},
-                    "weight_map": weight_map,
-                }
-                with open(index_path, "w") as f:
-                    json.dump(index_data, f, indent=2)
+            # repo_id is necessary when saving an adapter config, so it's compatible with HF.
+            # This json file is produced and saved in the download step.
+            # contents are {"repo_id": "some_model/some_model_version"}
+            repo_id_path = Path.joinpath(
+                self._checkpoint_dir.parent, REPO_ID_FNAME
+            ).with_suffix(".json")
+            self.repo_id = None
+            if repo_id_path.exists():
+                with open(repo_id_path, "r") as json_file:
+                    data = json.load(json_file)
+                    self.repo_id = data.get("repo_id")
 
             if training.ADAPTER_KEY in state_dict:
-
                 # TODO: saving it "as is" is a requirement because, if we only save with
                 # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
                 # convert_weights.peft_to_tune. The .pt format is not needed, but
@@ -273,7 +223,9 @@ class TorchtuneCheckpointer:
                         head_dim=config.get("head_dim", None),
                     )
                 )
-                output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME)
+                output_path = Path.joinpath(
+                    model_file_path, "adapter", ADAPTER_MODEL_FNAME
+                )
                 output_path.parent.mkdir(parents=True, exist_ok=True)
                 output_path = output_path.with_suffix(".safetensors")
                 save_file(
@@ -300,7 +252,7 @@ class TorchtuneCheckpointer:
                 )
 
                 output_path = Path.joinpath(
-                    model_file_path, ADAPTER_CONFIG_FNAME
+                    model_file_path, "adapter", ADAPTER_CONFIG_FNAME
                 ).with_suffix(".json")
                 with open(output_path, "w") as f:
                     json.dump(state_dict[training.ADAPTER_CONFIG], f)
@@ -313,7 +265,7 @@ class TorchtuneCheckpointer:
         # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
         # So its easy to run inference with the model using this epoch's checkpoint
         copy_files(
-            self._checkpoint_dir,
+            self._checkpoint_dir.parent,
            model_file_path,
             ignore_suffixes=SUFFIXES_TO_NOT_COPY,
         )
diff --git a/llama_stack/providers/inline/post_training/torchtune/config.py b/llama_stack/providers/inline/post_training/torchtune/config.py
index ecb80dd29..34a48589d 100644
--- a/llama_stack/providers/inline/post_training/torchtune/config.py
+++ b/llama_stack/providers/inline/post_training/torchtune/config.py
@@ -11,4 +11,4 @@ from pydantic import BaseModel
 
 class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
-    checkpoint_format: Optional[str] = "hf"
+    checkpoint_format: Optional[str] = "meta"
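A minimal consumption sketch, not part of the diff: it assumes torchtune's ADAPTER_MODEL_FNAME / ADAPTER_CONFIG_FNAME resolve to "adapter_model" / "adapter_config", uses a hypothetical output path and repo_id, and the peft/transformers loading step is illustrative only, to show what the adapter-only "hf" output is intended to be compatible with.

from pathlib import Path

from peft import PeftModel
from transformers import AutoModelForCausalLM

from llama_stack.providers.inline.post_training.torchtune.config import TorchtunePostTrainingConfig

# The provider default is now "meta"; HF (adapter-only) output must be requested explicitly.
config = TorchtunePostTrainingConfig(checkpoint_format="hf")

# Assumed layout after save_checkpoint(..., checkpoint_format="hf"), per the checkpointer code above:
#   <output_dir>/<model_id>-<training_algorithm>-<epoch>/adapter/adapter_model.safetensors
#   <output_dir>/<model_id>-<training_algorithm>-<epoch>/adapter/adapter_config.json
adapter_dir = Path("output") / "Llama3.2-3B-Instruct-sft-0" / "adapter"  # hypothetical path

# The base model comes from the repo_id recorded at download time (the REPO_ID_FNAME json above),
# e.g. {"repo_id": "meta-llama/Llama-3.2-3B-Instruct"}; hypothetical value here.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model = PeftModel.from_pretrained(base, adapter_dir)  # loads adapter config + adapter weights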