mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00
address comment
This commit is contained in:
parent
0a62ceecb7
commit
0fb674d77b
4 changed files with 155 additions and 184 deletions
|
@ -17,11 +17,11 @@ from torchtune.models import convert_weights
|
|||
from torchtune.training.checkpointing._utils import (
|
||||
ADAPTER_CONFIG_FNAME,
|
||||
ADAPTER_MODEL_FNAME,
|
||||
copy_files,
|
||||
ModelType,
|
||||
REPO_ID_FNAME,
|
||||
safe_torch_load,
|
||||
SUFFIXES_TO_NOT_COPY,
|
||||
ModelType,
|
||||
copy_files,
|
||||
safe_torch_load,
|
||||
)
|
||||
from torchtune.utils._logging import get_logger
|
||||
|
||||
|
@ -52,9 +52,7 @@ class TorchtuneCheckpointer:
|
|||
self._model_type = ModelType[model_type]
|
||||
self._output_dir = output_dir
|
||||
# get ckpt paths
|
||||
self._checkpoint_path = Path.joinpath(
|
||||
self._checkpoint_dir, self._checkpoint_file
|
||||
)
|
||||
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
|
||||
|
||||
def load_checkpoint(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
@ -67,13 +65,9 @@ class TorchtuneCheckpointer:
|
|||
llama3_vision_meta_to_tune,
|
||||
)
|
||||
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
|
||||
model_state_dict
|
||||
)
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
|
||||
else:
|
||||
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
|
||||
model_state_dict
|
||||
)
|
||||
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
|
||||
|
||||
# llama3_2 has tied weights, so we need to remove the output.weight key
|
||||
if self._model_type == ModelType.LLAMA3_2:
|
||||
|
@ -93,173 +87,154 @@ class TorchtuneCheckpointer:
|
|||
adapter_only: bool = False,
|
||||
checkpoint_format: str = "meta",
|
||||
) -> str:
|
||||
model_file_path = (
|
||||
Path(self._output_dir)
|
||||
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
)
|
||||
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
if checkpoint_format == "meta":
|
||||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# copy the related files for inference
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
|
||||
if not adapter_only:
|
||||
model_state_dict = state_dict[training.MODEL_KEY]
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
llama3_vision_tune_to_meta,
|
||||
)
|
||||
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
|
||||
model_state_dict
|
||||
)
|
||||
else:
|
||||
# llama3_2 has tied weights, so we need to add the output.weight key
|
||||
if (
|
||||
self._model_type == ModelType.LLAMA3_2
|
||||
and "output.weight" not in model_state_dict
|
||||
):
|
||||
model_state_dict["output.weight"] = model_state_dict[
|
||||
"tok_embeddings.weight"
|
||||
]
|
||||
|
||||
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
|
||||
model_state_dict
|
||||
)
|
||||
|
||||
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
|
||||
|
||||
torch.save(state_dict[training.MODEL_KEY], model_file_name)
|
||||
logger.info(
|
||||
"Model checkpoint of size "
|
||||
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {model_file_name}"
|
||||
)
|
||||
|
||||
if training.ADAPTER_KEY in state_dict:
|
||||
adapter_file_path = model_file_path / "adapter"
|
||||
adapter_file_path.mkdir(parents=True, exist_ok=True)
|
||||
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
|
||||
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
|
||||
logger.info(
|
||||
"Adapter checkpoint of size "
|
||||
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {adapter_file_name}"
|
||||
)
|
||||
|
||||
elif adapter_only:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
elif checkpoint_format == "hf":
|
||||
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
|
||||
elif checkpoint_format == "huggingface":
|
||||
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
||||
|
||||
# the config.json file contains model params needed for state dict conversion
|
||||
config = json.loads(
|
||||
Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text()
|
||||
)
|
||||
|
||||
# repo_id is necessary for when saving an adapter config, so its compatible with HF.
|
||||
# This json file is produced and saved in the download step.
|
||||
# contents are {"repo_id": "some_model/some_model_version"}
|
||||
repo_id_path = Path.joinpath(
|
||||
self._checkpoint_dir.parent, REPO_ID_FNAME
|
||||
).with_suffix(".json")
|
||||
self.repo_id = None
|
||||
if repo_id_path.exists():
|
||||
with open(repo_id_path, "r") as json_file:
|
||||
data = json.load(json_file)
|
||||
self.repo_id = data.get("repo_id")
|
||||
|
||||
if training.ADAPTER_KEY in state_dict:
|
||||
# TODO: saving it "as is" is a requirement because, if we only save with
|
||||
# convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
|
||||
# convert_weights.peft_to_tune. The .pt format is not needed, but
|
||||
# it is an easy way to distinguish the adapters. Ideally we should save only one.
|
||||
output_path = Path.joinpath(
|
||||
model_file_path, ADAPTER_MODEL_FNAME
|
||||
).with_suffix(".pt")
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(state_dict[training.ADAPTER_KEY], output_path)
|
||||
logger.info(
|
||||
"Adapter checkpoint of size "
|
||||
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
|
||||
f"saved to {output_path}"
|
||||
)
|
||||
|
||||
state_dict[training.ADAPTER_KEY] = (
|
||||
convert_weights.tune_to_peft_adapter_weights(
|
||||
state_dict[training.ADAPTER_KEY],
|
||||
num_heads=config["num_attention_heads"],
|
||||
num_kv_heads=config["num_key_value_heads"],
|
||||
dim=config["hidden_size"],
|
||||
head_dim=config.get("head_dim", None),
|
||||
)
|
||||
)
|
||||
output_path = Path.joinpath(
|
||||
model_file_path, "adapter", ADAPTER_MODEL_FNAME
|
||||
)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_path.with_suffix(".safetensors")
|
||||
save_file(
|
||||
state_dict[training.ADAPTER_KEY],
|
||||
output_path,
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
logger.info(
|
||||
"Adapter checkpoint of size "
|
||||
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
|
||||
f"saved to {output_path}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
if training.ADAPTER_CONFIG in state_dict:
|
||||
state_dict[training.ADAPTER_CONFIG] = (
|
||||
convert_weights.tune_to_peft_adapter_config(
|
||||
adapter_config=state_dict[training.ADAPTER_CONFIG],
|
||||
base_model_name_or_path=self.repo_id,
|
||||
)
|
||||
)
|
||||
|
||||
output_path = Path.joinpath(
|
||||
model_file_path, "adapter", ADAPTER_CONFIG_FNAME
|
||||
).with_suffix(".json")
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(state_dict[training.ADAPTER_CONFIG], f)
|
||||
logger.info(
|
||||
"Adapter checkpoint of size "
|
||||
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
|
||||
f"saved to {output_path}"
|
||||
)
|
||||
|
||||
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
|
||||
# So its easy to run inference with the model using this epoch's checkpoint
|
||||
copy_files(
|
||||
self._checkpoint_dir.parent,
|
||||
model_file_path,
|
||||
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
|
||||
)
|
||||
self._save_hf_format_checkpoint(model_file_path, state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported checkpoint format: {format}")
|
||||
return str(model_file_path)
|
||||
|
||||
def _save_meta_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: Dict[str, Any],
|
||||
adapter_only: bool = False,
|
||||
) -> None:
|
||||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# copy the related files for inference
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
|
||||
if not adapter_only:
|
||||
model_state_dict = state_dict[training.MODEL_KEY]
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
llama3_vision_tune_to_meta,
|
||||
)
|
||||
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
|
||||
else:
|
||||
# llama3_2 has tied weights, so we need to add the output.weight key
|
||||
if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
|
||||
model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]
|
||||
|
||||
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)
|
||||
|
||||
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
|
||||
|
||||
torch.save(state_dict[training.MODEL_KEY], model_file_name)
|
||||
logger.info(
|
||||
"Model checkpoint of size "
|
||||
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {model_file_name}"
|
||||
)
|
||||
|
||||
if training.ADAPTER_KEY in state_dict:
|
||||
adapter_file_path = model_file_path / "adapter"
|
||||
adapter_file_path.mkdir(parents=True, exist_ok=True)
|
||||
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
|
||||
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
|
||||
logger.info(
|
||||
"Adapter checkpoint of size "
|
||||
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {adapter_file_name}"
|
||||
)
|
||||
|
||||
elif adapter_only:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
def _save_hf_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: Dict[str, Any],
|
||||
) -> None:
|
||||
# the config.json file contains model params needed for state dict conversion
|
||||
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
|
||||
|
||||
# repo_id is necessary for when saving an adapter config, so its compatible with HF.
|
||||
# This json file is produced and saved in the download step.
|
||||
# contents are {"repo_id": "some_model/some_model_version"}
|
||||
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
|
||||
self.repo_id = None
|
||||
if repo_id_path.exists():
|
||||
with open(repo_id_path, "r") as json_file:
|
||||
data = json.load(json_file)
|
||||
self.repo_id = data.get("repo_id")
|
||||
|
||||
if training.ADAPTER_KEY in state_dict:
|
||||
# TODO: saving it "as is" is a requirement because, if we only save with
|
||||
# convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
|
||||
# convert_weights.peft_to_tune. The .pt format is not needed, but
|
||||
# it is an easy way to distinguish the adapters. Ideally we should save only one.
|
||||
output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(state_dict[training.ADAPTER_KEY], output_path)
|
||||
logger.info(
|
||||
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
|
||||
)
|
||||
|
||||
state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
|
||||
state_dict[training.ADAPTER_KEY],
|
||||
num_heads=config["num_attention_heads"],
|
||||
num_kv_heads=config["num_key_value_heads"],
|
||||
dim=config["hidden_size"],
|
||||
head_dim=config.get("head_dim", None),
|
||||
)
|
||||
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_path.with_suffix(".safetensors")
|
||||
save_file(
|
||||
state_dict[training.ADAPTER_KEY],
|
||||
output_path,
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
logger.info(
|
||||
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
if training.ADAPTER_CONFIG in state_dict:
|
||||
state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
|
||||
adapter_config=state_dict[training.ADAPTER_CONFIG],
|
||||
base_model_name_or_path=self.repo_id,
|
||||
)
|
||||
|
||||
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(state_dict[training.ADAPTER_CONFIG], f)
|
||||
logger.info(
|
||||
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
|
||||
)
|
||||
|
||||
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
|
||||
# So its easy to run inference with the model using this epoch's checkpoint
|
||||
copy_files(
|
||||
self._checkpoint_dir.parent,
|
||||
model_file_path,
|
||||
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
|
||||
)
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[str] = "meta"
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
|
|
|
@ -16,7 +16,6 @@ distribution_spec:
|
|||
- inline::torchtune
|
||||
datasetio:
|
||||
- inline::localfs
|
||||
- remote::huggingface
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
agents:
|
||||
|
|
|
@ -38,9 +38,6 @@ providers:
|
|||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
datasetio:
|
||||
- provider_id: huggingface-0
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
|
@ -52,7 +49,7 @@ providers:
|
|||
- provider_id: torchtune-post-training
|
||||
provider_type: inline::torchtune
|
||||
config: {
|
||||
checkpoint_format: hf
|
||||
checkpoint_format: huggingface
|
||||
}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue