mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 10:42:39 +00:00)

commit 09e9445a11 (parent bbb1542b95)
2 changed files with 24 additions and 72 deletions
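
The commit carries no message, but read together the hunks below do five things: in TorchtuneCheckpointer.save_checkpoint, the format parameter is renamed to checkpoint_format (the old name shadowed the Python builtin); config.json and a repo-id sidecar file are now resolved from self._checkpoint_dir.parent; the full-model safetensors sharding and index-writing path is removed, so the "hf" branch now saves adapter weights only; adapter weights and adapter config move into an adapter/ subdirectory; and TorchtunePostTrainingConfig.checkpoint_format defaults to "meta" instead of "hf".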

@@ -4,12 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import os
 import shutil
 from pathlib import Path
 from typing import Any, Dict, List
 
 import torch
+from safetensors.torch import save_file
 from torchtune import training
 from torchtune.models import convert_weights
 from torchtune.training.checkpointing._utils import (

@@ -103,7 +105,7 @@ class TorchtuneCheckpointer:
             Path(self._output_dir)
             / f"{self._model_id}-{self._training_algorithm}-{epoch}"
         )
-        if format == "meta":
+        if checkpoint_format == "meta":
             model_file_path.mkdir(parents=True, exist_ok=True)
 
             # copy the related files for inference
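
A side note on the rename: the old parameter name `format` shadowed Python's builtin format(). A minimal illustration of the hazard (the function name and body here are invented for the example, not taken from the file):

def save(format: str = "meta") -> str:
    # With the parameter in scope, the builtin is unreachable by its usual
    # name: format(3.14159, ".2f") would raise TypeError ('str' object is
    # not callable) instead of formatting the number.
    return format

Renaming the parameter to checkpoint_format removes the collision and says what the value actually is.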

@@ -176,79 +178,27 @@ class TorchtuneCheckpointer:
             )
 
             print("model_file_path", str(model_file_path))
-        elif format == "hf":
+        elif checkpoint_format == "hf":
+            # Note: for saving Hugging Face format checkpoints, we only support saving adapter weights now
+
             # the config.json file contains model params needed for state dict conversion
             config = json.loads(
-                Path.joinpath(self._checkpoint_dir, "config.json").read_text()
+                Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text()
             )
-            if not adapter_only:
-                state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf(
-                    state_dict[training.MODEL_KEY],
-                    num_heads=config["num_attention_heads"],
-                    num_kv_heads=config["num_key_value_heads"],
-                    dim=config["hidden_size"],
-                    head_dim=config.get("head_dim", None),
-                )
-
-            # split the state_dict into separate dicts, one for each output checkpoint file
-            # e.g. split_state_dicts= {
-            #     "0001": {"key1": tensor1, "key2": tensor2},
-            #     "0002": {"key3": tensor3}
-            # }
-            split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
-            total_size = 0
-            for key, weight in state_dict[training.MODEL_KEY].items():
-                cpt_idx = self._weight_map[key]
-
-                # initialize dict
-                if cpt_idx not in split_state_dicts:
-                    split_state_dicts[cpt_idx] = {}
-
-                split_state_dicts[cpt_idx].update({key: weight})
-                total_size += weight.numel() * weight.element_size()
-
-            # write the partitioned state dicts to the right checkpoint file
-            # e.g. model-00001-of-00004.safetensors, model-00002-of-00004.safetensors, etc
-            num_shards = len(split_state_dicts)
-            map_original_name_to_new_name = {}
-            for cpt_idx, model_state_dict in split_state_dicts.items():
-                # TODO: We should probably use the original shard name and just add a prefix
-                # however, having the SHARD_FNAME standardizes our checkpoints
-                shard_name = SHARD_FNAME.format(
-                    cpt_idx=f"{cpt_idx}".zfill(5),
-                    num_shards=f"{num_shards}".zfill(5),
-                )
-                map_original_name_to_new_name[cpt_idx] = shard_name
-                output_path = Path.joinpath(model_file_path, shard_name)
-                output_path.parent.mkdir(parents=True, exist_ok=True)
-                output_path = output_path.with_suffix(".safetensors")
-                save_file(model_state_dict, output_path, metadata={"format": "pt"})
-
-                logger.info(
-                    "Model checkpoint of size "
-                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
-                    f"saved to {output_path}"
-                )
-
-            # Save the appropriate index file based on serialization format
-            # e.g. {metadata: {total_size: 1234}, weight_map: {"key1": "model_0001.safetensors", "key2": "model_0002.safetensors"}}
-            weight_map = {
-                k: map_original_name_to_new_name[cpt_idx] + ".safetensors"
-                for k, cpt_idx in self._weight_map.items()
-            }
-            index_file_name = SAFETENSOR_INDEX_FNAME
-
-            index_path = Path.joinpath(model_file_path, index_file_name)
-
-            index_data = {
-                "metadata": {"total_size": total_size},
-                "weight_map": weight_map,
-            }
-            with open(index_path, "w") as f:
-                json.dump(index_data, f, indent=2)
+            # repo_id is necessary for when saving an adapter config, so it's compatible with HF.
+            # This json file is produced and saved in the download step.
+            # contents are {"repo_id": "some_model/some_model_version"}
+            repo_id_path = Path.joinpath(
+                self._checkpoint_dir.parent, REPO_ID_FNAME
+            ).with_suffix(".json")
+            self.repo_id = None
+            if repo_id_path.exists():
+                with open(repo_id_path, "r") as json_file:
+                    data = json.load(json_file)
+                    self.repo_id = data.get("repo_id")
 
             if training.ADAPTER_KEY in state_dict:
 
                 # TODO: saving it "as is" is a requirement because, if we only save with
                 # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
                 # convert_weights.peft_to_tune. The .pt format is not needed, but
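
For orientation, REPO_ID_FNAME and the other file-name constants presumably come from the truncated torchtune.training.checkpointing._utils import near the top of the file. A self-contained sketch of the sidecar lookup the new code performs, with the file-name value assumed rather than confirmed:

import json
from pathlib import Path

def read_repo_id(checkpoint_dir: Path, repo_id_fname: str = "original_repo_id") -> str | None:
    # The sidecar sits one level above the weights directory, matching the
    # diff's use of self._checkpoint_dir.parent; its contents look like
    # {"repo_id": "some_model/some_model_version"}.
    sidecar = (checkpoint_dir.parent / repo_id_fname).with_suffix(".json")
    if not sidecar.exists():
        return None
    with open(sidecar, "r") as f:
        return json.load(f).get("repo_id")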

@@ -273,7 +223,9 @@ class TorchtuneCheckpointer:
                         head_dim=config.get("head_dim", None),
                     )
                 )
-                output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME)
+                output_path = Path.joinpath(
+                    model_file_path, "adapter", ADAPTER_MODEL_FNAME
+                )
                 output_path.parent.mkdir(parents=True, exist_ok=True)
                 output_path = output_path.with_suffix(".safetensors")
                 save_file(
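
Net effect of the change above: HF-format adapter weights now land in an adapter/ subdirectory instead of the checkpoint root. A runnable sketch of the resulting write (directory name and tensor keys are illustrative, not taken from the file):

from pathlib import Path

import torch
from safetensors.torch import save_file

model_file_path = Path("/tmp/my-model-LoRA-0")  # stands in for the epoch output dir
output_path = model_file_path / "adapter" / "adapter_model"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path = output_path.with_suffix(".safetensors")
# save_file persists a flat {name: tensor} dict to one .safetensors file.
save_file({"lora_a.weight": torch.zeros(8, 16)}, str(output_path))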

@@ -300,7 +252,7 @@ class TorchtuneCheckpointer:
                 )
 
                 output_path = Path.joinpath(
-                    model_file_path, ADAPTER_CONFIG_FNAME
+                    model_file_path, "adapter", ADAPTER_CONFIG_FNAME
                 ).with_suffix(".json")
                 with open(output_path, "w") as f:
                     json.dump(state_dict[training.ADAPTER_CONFIG], f)
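
The repo_id captured earlier is what makes this adapter_config.json usable by HF tooling, which resolves the base model from the config. For orientation, a minimal PEFT-style adapter config looks roughly like this (field values are illustrative; the real dict comes from state_dict[training.ADAPTER_CONFIG]):

example_adapter_config = {
    "base_model_name_or_path": "some_model/some_model_version",  # from the repo_id sidecar
    "peft_type": "LORA",
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
}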

@@ -313,7 +265,7 @@ class TorchtuneCheckpointer:
         # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
         # So it's easy to run inference with the model using this epoch's checkpoint
         copy_files(
-            self._checkpoint_dir,
+            self._checkpoint_dir.parent,
             model_file_path,
             ignore_suffixes=SUFFIXES_TO_NOT_COPY,
         )
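
copy_files and SUFFIXES_TO_NOT_COPY likewise appear to come from the truncated torchtune import. A hedged sketch of the behavior relied on here (not torchtune's actual implementation):

import shutil
from pathlib import Path

def copy_support_files(src: Path, dst: Path, ignore_suffixes: list[str]) -> None:
    # Copy inference support files (tokenizer, config, etc.), skipping
    # anything that looks like model weights, e.g. ".pt", ".safetensors".
    dst.mkdir(parents=True, exist_ok=True)
    for item in src.iterdir():
        if item.is_file() and item.suffix not in ignore_suffixes:
            shutil.copy2(item, dst / item.name)

Pointing it at self._checkpoint_dir.parent is consistent with the rest of the commit: support files now sit one level above the weights directory.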

@@ -11,4 +11,4 @@ from pydantic import BaseModel
 
 class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
-    checkpoint_format: Optional[str] = "hf"
+    checkpoint_format: Optional[str] = "meta"
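
And the practical effect of the default flip, assuming the class exactly as shown:

from typing import Optional

from pydantic import BaseModel

class TorchtunePostTrainingConfig(BaseModel):
    torch_seed: Optional[int] = None
    checkpoint_format: Optional[str] = "meta"

cfg = TorchtunePostTrainingConfig()
assert cfg.checkpoint_format == "meta"  # Meta-format checkpoints by default
cfg_hf = TorchtunePostTrainingConfig(checkpoint_format="hf")  # opt in to adapter-only HF export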