Botao Chen 2025-01-22 16:21:40 -08:00
parent bbb1542b95
commit 09e9445a11
2 changed files with 24 additions and 72 deletions

View file

@@ -4,12 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 import os
 import shutil
 from pathlib import Path
 from typing import Any, Dict, List

 import torch
+from safetensors.torch import save_file
 from torchtune import training
 from torchtune.models import convert_weights
 from torchtune.training.checkpointing._utils import (
@@ -103,7 +105,7 @@ class TorchtuneCheckpointer:
                 Path(self._output_dir)
                 / f"{self._model_id}-{self._training_algorithm}-{epoch}"
             )

-        if format == "meta":
+        if checkpoint_format == "meta":
             model_file_path.mkdir(parents=True, exist_ok=True)
             # copy the related files for inference
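
For orientation, a minimal sketch of the dispatch this rename cleans up; the function name and signature below are hypothetical, not the provider's actual API. Renaming the parameter from `format` to `checkpoint_format` also stops shadowing Python's built-in `format`.

from pathlib import Path

def save_checkpoint_dir(output_dir: str, model_id: str, algorithm: str,
                        epoch: int, checkpoint_format: str = "meta") -> Path:
    # `checkpoint_format` replaces the old parameter name `format`,
    # which shadowed the Python builtin of the same name
    model_file_path = Path(output_dir) / f"{model_id}-{algorithm}-{epoch}"
    if checkpoint_format == "meta":
        # Meta/native torchtune layout: weights plus copied inference files
        model_file_path.mkdir(parents=True, exist_ok=True)
    elif checkpoint_format == "hf":
        # Hugging Face layout: adapter artifacts under an "adapter" subdirectory
        (model_file_path / "adapter").mkdir(parents=True, exist_ok=True)
    else:
        raise ValueError(f"unsupported checkpoint format: {checkpoint_format}")
    return model_file_path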
@ -176,79 +178,27 @@ class TorchtuneCheckpointer:
) )
print("model_file_path", str(model_file_path)) print("model_file_path", str(model_file_path))
elif format == "hf": elif checkpoint_format == "hf":
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
# the config.json file contains model params needed for state dict conversion # the config.json file contains model params needed for state dict conversion
config = json.loads( config = json.loads(
Path.joinpath(self._checkpoint_dir, "config.json").read_text() Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text()
) )
if not adapter_only:
state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf(
state_dict[training.MODEL_KEY],
num_heads=config["num_attention_heads"],
num_kv_heads=config["num_key_value_heads"],
dim=config["hidden_size"],
head_dim=config.get("head_dim", None),
)
# split the state_dict into separate dicts, one for each output checkpoint file # repo_id is necessary for when saving an adapter config, so its compatible with HF.
# e.g. split_state_dicts= { # This json file is produced and saved in the download step.
# "0001": {"key1": tensor1, "key2": tensor2}, # contents are {"repo_id": "some_model/some_model_version"}
# "0002": {"key3": tensor3} repo_id_path = Path.joinpath(
# } self._checkpoint_dir.parent, REPO_ID_FNAME
split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {} ).with_suffix(".json")
total_size = 0 self.repo_id = None
for key, weight in state_dict[training.MODEL_KEY].items(): if repo_id_path.exists():
cpt_idx = self._weight_map[key] with open(repo_id_path, "r") as json_file:
data = json.load(json_file)
# initialize dict self.repo_id = data.get("repo_id")
if cpt_idx not in split_state_dicts:
split_state_dicts[cpt_idx] = {}
split_state_dicts[cpt_idx].update({key: weight})
total_size += weight.numel() * weight.element_size()
# write the partitioned state dicts to the right checkpoint file
# e.g. model-00001-of-00004.safetensors, model-00002-of-00004.safetensors, etc
num_shards = len(split_state_dicts)
map_original_name_to_new_name = {}
for cpt_idx, model_state_dict in split_state_dicts.items():
# TODO: We should probably use the original shard name and just add a prefix
# however, having the SHARD_FNAME standardizes our checkpoints
shard_name = SHARD_FNAME.format(
cpt_idx=f"{cpt_idx}".zfill(5),
num_shards=f"{num_shards}".zfill(5),
)
map_original_name_to_new_name[cpt_idx] = shard_name
output_path = Path.joinpath(model_file_path, shard_name)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path = output_path.with_suffix(".safetensors")
save_file(model_state_dict, output_path, metadata={"format": "pt"})
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
f"saved to {output_path}"
)
# Save the appropriate index file based on serialization format
# e.g. {metadata: {total_size: 1234}, weight_map: {"key1": "model_0001.safetensors", "key2": "model_0002.safetensors"}}
weight_map = {
k: map_original_name_to_new_name[cpt_idx] + ".safetensors"
for k, cpt_idx in self._weight_map.items()
}
index_file_name = SAFETENSOR_INDEX_FNAME
index_path = Path.joinpath(model_file_path, index_file_name)
index_data = {
"metadata": {"total_size": total_size},
"weight_map": weight_map,
}
with open(index_path, "w") as f:
json.dump(index_data, f, indent=2)
if training.ADAPTER_KEY in state_dict: if training.ADAPTER_KEY in state_dict:
# TODO: saving it "as is" is a requirement because, if we only save with # TODO: saving it "as is" is a requirement because, if we only save with
# convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
# convert_weights.peft_to_tune. The .pt format is not needed, but # convert_weights.peft_to_tune. The .pt format is not needed, but
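
The new branch drops full-weight HF export in favor of looking up the repo_id recorded at download time. A self-contained sketch of that lookup, assuming torchtune's REPO_ID_FNAME constant resolves to "original_repo_id"; the helper name is hypothetical.

import json
from pathlib import Path
from typing import Optional

def load_repo_id(checkpoint_dir: Path) -> Optional[str]:
    # the download step writes <parent>/original_repo_id.json next to the checkpoint dir
    repo_id_path = (checkpoint_dir.parent / "original_repo_id").with_suffix(".json")
    if not repo_id_path.exists():
        return None
    with open(repo_id_path, "r") as json_file:
        data = json.load(json_file)
    # e.g. {"repo_id": "some_model/some_model_version"}
    return data.get("repo_id")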
@@ -273,7 +223,9 @@ class TorchtuneCheckpointer:
                         head_dim=config.get("head_dim", None),
                     )
                 )
-                output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME)
+                output_path = Path.joinpath(
+                    model_file_path, "adapter", ADAPTER_MODEL_FNAME
+                )
                 output_path.parent.mkdir(parents=True, exist_ok=True)
                 output_path = output_path.with_suffix(".safetensors")
                 save_file(
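
The adapter weights now land under an "adapter" subdirectory. A sketch of that write path, assuming torchtune's ADAPTER_MODEL_FNAME resolves to "adapter_model"; the helper name and the contiguous-CPU copy are illustrative, not taken from this diff.

from pathlib import Path
from typing import Dict

import torch
from safetensors.torch import save_file

def save_adapter_weights(model_file_path: Path,
                         adapter_state: Dict[str, torch.Tensor]) -> Path:
    # mirror the new layout: <model_file_path>/adapter/adapter_model.safetensors
    output_path = model_file_path / "adapter" / "adapter_model"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path = output_path.with_suffix(".safetensors")
    # safetensors requires contiguous CPU tensors
    tensors = {k: v.contiguous().cpu() for k, v in adapter_state.items()}
    save_file(tensors, output_path, metadata={"format": "pt"})
    return output_path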
@@ -300,7 +252,7 @@ class TorchtuneCheckpointer:
                 )
                 output_path = Path.joinpath(
-                    model_file_path, ADAPTER_CONFIG_FNAME
+                    model_file_path, "adapter", ADAPTER_CONFIG_FNAME
                 ).with_suffix(".json")
                 with open(output_path, "w") as f:
                     json.dump(state_dict[training.ADAPTER_CONFIG], f)
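
The adapter config moves into the same "adapter" subdirectory, next to the weights. A sketch of the write, assuming ADAPTER_CONFIG_FNAME resolves to "adapter_config"; injecting the repo_id as "base_model_name_or_path" is a guess at how PEFT would locate the base model, since this hunk only stores the repo_id.

import json
from pathlib import Path
from typing import Any, Dict, Optional

def save_adapter_config(model_file_path: Path,
                        adapter_config: Dict[str, Any],
                        repo_id: Optional[str]) -> Path:
    # assumption: HF/PEFT resolves the base model via base_model_name_or_path
    if repo_id is not None:
        adapter_config["base_model_name_or_path"] = repo_id
    output_path = (model_file_path / "adapter" / "adapter_config").with_suffix(".json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(adapter_config, f)
    return output_path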
@@ -313,7 +265,7 @@ class TorchtuneCheckpointer:
         # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
         # So it's easy to run inference with the model using this epoch's checkpoint
         copy_files(
-            self._checkpoint_dir,
+            self._checkpoint_dir.parent,
             model_file_path,
             ignore_suffixes=SUFFIXES_TO_NOT_COPY,
         )
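
copy_files and SUFFIXES_TO_NOT_COPY are defined elsewhere in the provider, so the sketch below is only a guess at what a helper matching this call site might look like; name, signature, and behavior are assumptions.

import shutil
from pathlib import Path
from typing import List, Optional

def copy_files(src: Path, dst: Path,
               ignore_suffixes: Optional[List[str]] = None) -> None:
    # copy everything from src to dst except files whose suffix is excluded,
    # e.g. model weights (".safetensors") and weight-map files
    ignore_suffixes = ignore_suffixes or []
    dst.mkdir(parents=True, exist_ok=True)
    for item in src.iterdir():
        if item.suffix in ignore_suffixes:
            continue
        if item.is_dir():
            shutil.copytree(item, dst / item.name, dirs_exist_ok=True)
        else:
            shutil.copy2(item, dst / item.name)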

View file

@@ -11,4 +11,4 @@ from pydantic import BaseModel

 class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
-    checkpoint_format: Optional[str] = "hf"
+    checkpoint_format: Optional[str] = "meta"
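
With this change the default checkpoint format flips from "hf" to "meta", matching the checkpointer change above that narrows HF output to adapter-only export. A quick check of the new behavior, with the pydantic model copied from the diff:

from typing import Optional

from pydantic import BaseModel

class TorchtunePostTrainingConfig(BaseModel):
    torch_seed: Optional[int] = None
    checkpoint_format: Optional[str] = "meta"

# the Meta-native layout is now the default; HF adapter export is opt-in
assert TorchtunePostTrainingConfig().checkpoint_format == "meta"
assert TorchtunePostTrainingConfig(checkpoint_format="hf").checkpoint_format == "hf"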