Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 10:42:39 +00:00)
temp commit

Commit b75e671c3b (parent 07b87365ab)
3 changed files with 261 additions and 86 deletions
@@ -12,7 +12,23 @@ from typing import Any, Dict, List
 import torch
 from torchtune import training
 from torchtune.models import convert_weights
-from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.training.checkpointing._utils import (
+    ADAPTER_CONFIG_FNAME,
+    ADAPTER_MODEL_FNAME,
+    check_outdir_not_in_ckptdir,
+    copy_files,
+    get_adapter_checkpoint_path,
+    get_model_checkpoint_path,
+    get_recipe_checkpoint_path,
+    ModelType,
+    RECIPE_STATE_DIRNAME,
+    REPO_ID_FNAME,
+    safe_torch_load,
+    SAFETENSOR_INDEX_FNAME,
+    SHARD_FNAME,
+    SUFFIXES_TO_NOT_COPY,
+    TORCH_INDEX_FNAME,
+)
 from torchtune.utils._logging import get_logger
 
 logger = get_logger("DEBUG")
@@ -81,83 +97,239 @@ class TorchtuneCheckpointer:
         state_dict: Dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
+        checkpoint_format: str = "meta",
     ) -> str:
         model_file_path = (
             Path(self._output_dir)
             / f"{self._model_id}-{self._training_algorithm}-{epoch}"
         )
-        model_file_path.mkdir(parents=True, exist_ok=True)
-
-        # copy the related files for inference
-        source_path = Path.joinpath(self._checkpoint_dir, "params.json")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "params.json"),
-            )
-        source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "tokenizer.model"),
-            )
-        source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "orig_params.json"),
-            )
-
-        if not adapter_only:
-            model_state_dict = state_dict[training.MODEL_KEY]
-            if self._model_type == ModelType.LLAMA3_VISION:
-                from torchtune.models.llama3_2_vision._convert_weights import (
-                    llama3_vision_tune_to_meta,
-                )
-
-                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
-                    model_state_dict
-                )
-            else:
-                # llama3_2 has tied weights, so we need to add the output.weight key
-                if (
-                    self._model_type == ModelType.LLAMA3_2
-                    and "output.weight" not in model_state_dict
-                ):
-                    model_state_dict["output.weight"] = model_state_dict[
-                        "tok_embeddings.weight"
-                    ]
-
-                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
-                    model_state_dict
-                )
-
-            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
-
-            torch.save(state_dict[training.MODEL_KEY], model_file_name)
-            logger.info(
-                "Model checkpoint of size "
-                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
-                f"saved to {model_file_name}"
-            )
-
-        if training.ADAPTER_KEY in state_dict:
-            adapter_file_path = model_file_path / "adapter"
-            adapter_file_path.mkdir(parents=True, exist_ok=True)
-            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
-            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
-            logger.info(
-                "Adapter checkpoint of size "
-                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
-                f"saved to {adapter_file_name}"
-            )
-
-        elif adapter_only:
-            raise ValueError(
-                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
-            )
-
-        print("model_file_path", str(model_file_path))
+        if checkpoint_format == "meta":
+            model_file_path.mkdir(parents=True, exist_ok=True)
+
+            # copy the related files for inference
+            source_path = Path.joinpath(self._checkpoint_dir, "params.json")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "params.json"),
+                )
+            source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "tokenizer.model"),
+                )
+            source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "orig_params.json"),
+                )
+
+            if not adapter_only:
+                model_state_dict = state_dict[training.MODEL_KEY]
+                if self._model_type == ModelType.LLAMA3_VISION:
+                    from torchtune.models.llama3_2_vision._convert_weights import (
+                        llama3_vision_tune_to_meta,
+                    )
+
+                    state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
+                        model_state_dict
+                    )
+                else:
+                    # llama3_2 has tied weights, so we need to add the output.weight key
+                    if (
+                        self._model_type == ModelType.LLAMA3_2
+                        and "output.weight" not in model_state_dict
+                    ):
+                        model_state_dict["output.weight"] = model_state_dict[
+                            "tok_embeddings.weight"
+                        ]
+
+                    state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
+                        model_state_dict
+                    )
+
+                model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
+
+                torch.save(state_dict[training.MODEL_KEY], model_file_name)
+                logger.info(
+                    "Model checkpoint of size "
+                    f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
+                    f"saved to {model_file_name}"
+                )
+
+            if training.ADAPTER_KEY in state_dict:
+                adapter_file_path = model_file_path / "adapter"
+                adapter_file_path.mkdir(parents=True, exist_ok=True)
+                adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
+                torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
+                    f"saved to {adapter_file_name}"
+                )
+
+            elif adapter_only:
+                raise ValueError(
+                    "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+                )
+
+            print("model_file_path", str(model_file_path))
+        elif checkpoint_format == "hf":
+            # the config.json file contains model params needed for state dict conversion
+            config = json.loads(
+                Path.joinpath(self._checkpoint_dir, "config.json").read_text()
+            )
+            if not adapter_only:
+                state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf(
+                    state_dict[training.MODEL_KEY],
+                    num_heads=config["num_attention_heads"],
+                    num_kv_heads=config["num_key_value_heads"],
+                    dim=config["hidden_size"],
+                    head_dim=config.get("head_dim", None),
+                )
+
+                # split the state_dict into separate dicts, one for each output checkpoint file
+                # e.g. split_state_dicts = {
+                #     "0001": {"key1": tensor1, "key2": tensor2},
+                #     "0002": {"key3": tensor3},
+                # }
+                split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
+                total_size = 0
+                for key, weight in state_dict[training.MODEL_KEY].items():
+                    cpt_idx = self._weight_map[key]
+
+                    # initialize dict
+                    if cpt_idx not in split_state_dicts:
+                        split_state_dicts[cpt_idx] = {}
+
+                    split_state_dicts[cpt_idx].update({key: weight})
+                    total_size += weight.numel() * weight.element_size()
+
+                # write the partitioned state dicts to the right checkpoint file
+                # e.g. model-00001-of-00004.safetensors, model-00002-of-00004.safetensors, etc
+                num_shards = len(split_state_dicts)
+                map_original_name_to_new_name = {}
+                for cpt_idx, model_state_dict in split_state_dicts.items():
+                    # TODO: We should probably use the original shard name and just add a prefix
+                    # however, having the SHARD_FNAME standardizes our checkpoints
+                    shard_name = SHARD_FNAME.format(
+                        cpt_idx=f"{cpt_idx}".zfill(5),
+                        num_shards=f"{num_shards}".zfill(5),
+                    )
+                    map_original_name_to_new_name[cpt_idx] = shard_name
+                    output_path = Path.joinpath(model_file_path, shard_name)
+                    output_path.parent.mkdir(parents=True, exist_ok=True)
+                    output_path = output_path.with_suffix(".safetensors")
+                    save_file(model_state_dict, output_path, metadata={"format": "pt"})
+
+                    logger.info(
+                        "Model checkpoint of size "
+                        f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                        f"saved to {output_path}"
+                    )
+
+                # Save the appropriate index file based on serialization format
+                # e.g. {metadata: {total_size: 1234}, weight_map: {"key1": "model_0001.safetensors", "key2": "model_0002.safetensors"}}
+                weight_map = {
+                    k: map_original_name_to_new_name[cpt_idx] + ".safetensors"
+                    for k, cpt_idx in self._weight_map.items()
+                }
+                index_file_name = SAFETENSOR_INDEX_FNAME
+
+                index_path = Path.joinpath(model_file_path, index_file_name)
+
+                index_data = {
+                    "metadata": {"total_size": total_size},
+                    "weight_map": weight_map,
+                }
+                with open(index_path, "w") as f:
+                    json.dump(index_data, f, indent=2)
+
+            if training.ADAPTER_KEY in state_dict:
+                # TODO: saving it "as is" is a requirement because, if we only save with
+                # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
+                # convert_weights.peft_to_tune. The .pt format is not needed, but
+                # it is an easy way to distinguish the adapters. Ideally we should save only one.
+                output_path = Path.joinpath(
+                    model_file_path, ADAPTER_MODEL_FNAME
+                ).with_suffix(".pt")
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                torch.save(state_dict[training.ADAPTER_KEY], output_path)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
+                )
+
+                state_dict[training.ADAPTER_KEY] = (
+                    convert_weights.tune_to_peft_adapter_weights(
+                        state_dict[training.ADAPTER_KEY],
+                        num_heads=config["num_attention_heads"],
+                        num_kv_heads=config["num_key_value_heads"],
+                        dim=config["hidden_size"],
+                        head_dim=config.get("head_dim", None),
+                    )
+                )
+                output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME)
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                output_path = output_path.with_suffix(".safetensors")
+                save_file(
+                    state_dict[training.ADAPTER_KEY],
+                    output_path,
+                    metadata={"format": "pt"},
+                )
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
+                )
+            elif adapter_only:
+                raise ValueError(
+                    "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+                )
+
+            if training.ADAPTER_CONFIG in state_dict:
+                state_dict[training.ADAPTER_CONFIG] = (
+                    convert_weights.tune_to_peft_adapter_config(
+                        adapter_config=state_dict[training.ADAPTER_CONFIG],
+                        base_model_name_or_path=self.repo_id,
+                    )
+                )
+
+                output_path = Path.joinpath(
+                    model_file_path, ADAPTER_CONFIG_FNAME
+                ).with_suffix(".json")
+                with open(output_path, "w") as f:
+                    json.dump(state_dict[training.ADAPTER_CONFIG], f)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
+                )
+
+            # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
+            # So it's easy to run inference with the model using this epoch's checkpoint
+            copy_files(
+                self._checkpoint_dir,
+                model_file_path,
+                ignore_suffixes=SUFFIXES_TO_NOT_COPY,
+            )
+            logger.info("Saving final epoch checkpoint.")
+            if adapter_only:
+                logger.info(
+                    "Please note that you have set adapter_only=True, so only adapter weights will be saved."
+                    "You need to merge the adapter weights into your base model for further use. "
+                    f"See {self.__class__.__name__}.save_checkpoint for more details."
+                )
+            else:
+                logger.info(
+                    "The full model checkpoint, including all weights and configurations, has been saved successfully."
+                    "You can now use this checkpoint for further training or inference."
+                )
+        else:
+            raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
 
         return str(model_file_path)
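Note on the new "hf" branch above: it converts the tuned weights to HF naming, buckets them by the shard index recorded in self._weight_map, writes one .safetensors file per bucket, and emits an index JSON that maps every weight name to its shard. Below is a minimal, self-contained sketch of that pattern; the tensor names, output directory, shard-name template, and index filename are illustrative assumptions, not the checkpointer's own constants.

import json
from pathlib import Path

import torch
from safetensors.torch import save_file

output_dir = Path("/tmp/hf_ckpt_sketch")  # illustrative location
output_dir.mkdir(parents=True, exist_ok=True)

# Stand-ins for the converted state dict and the key -> shard-index map
# (the checkpointer uses state_dict[training.MODEL_KEY] and self._weight_map).
state_dict = {
    "model.embed_tokens.weight": torch.zeros(8, 4),
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
    "lm_head.weight": torch.zeros(8, 4),
}
weight_map = {
    "model.embed_tokens.weight": "0001",
    "model.layers.0.self_attn.q_proj.weight": "0001",
    "lm_head.weight": "0002",
}

# Bucket weights by shard index and accumulate total byte size for the index metadata.
split_state_dicts = {}
total_size = 0
for key, weight in state_dict.items():
    split_state_dicts.setdefault(weight_map[key], {})[key] = weight
    total_size += weight.numel() * weight.element_size()

# Write one .safetensors file per bucket, e.g. model-00001-of-00002.safetensors.
num_shards = len(split_state_dicts)
shard_names = {}
for cpt_idx, shard in split_state_dicts.items():
    shard_name = f"model-{cpt_idx.zfill(5)}-of-{str(num_shards).zfill(5)}.safetensors"
    shard_names[cpt_idx] = shard_name
    save_file(shard, str(output_dir / shard_name), metadata={"format": "pt"})

# Index file mapping every weight to its shard, in the layout HF-style loaders expect.
index_data = {
    "metadata": {"total_size": total_size},
    "weight_map": {k: shard_names[idx] for k, idx in weight_map.items()},
}
with open(output_dir / "model.safetensors.index.json", "w") as f:
    json.dump(index_data, f, indent=2)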
@@ -11,3 +11,4 @@ from pydantic import BaseModel
 
 class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
+    checkpoint_format: Optional[str] = "hf"
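For reference, a minimal sketch of the resulting config surface; the values passed in are illustrative, not taken from a shipped run configuration.

from typing import Optional

from pydantic import BaseModel


class TorchtunePostTrainingConfig(BaseModel):
    torch_seed: Optional[int] = None
    checkpoint_format: Optional[str] = "hf"


# "hf" (the new default) selects the sharded .safetensors layout shown above;
# "meta" keeps the previous consolidated.00.pth layout.
cfg = TorchtunePostTrainingConfig(torch_seed=42, checkpoint_format="meta")
print(cfg.checkpoint_format)  # -> meta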
@@ -15,24 +15,6 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
-from torch import nn
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import padded_collate_sft
-
-from torchtune.modules.loss import CEWithChunkedOutputLoss
-from torchtune.modules.peft import (
-    get_adapter_params,
-    get_adapter_state_dict,
-    get_lora_module_names,
-    get_merged_lora_ckpt,
-    set_trainable_params,
-    validate_missing_and_unexpected_for_lora,
-)
-from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
-
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO

@@ -60,6 +42,24 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from torch import nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import modules, training, utils as torchtune_utils
+from torchtune.data import padded_collate_sft
+
+from torchtune.modules.loss import CEWithChunkedOutputLoss
+from torchtune.modules.peft import (
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 log = logging.getLogger(__name__)
@@ -129,6 +129,7 @@ class LoraFinetuningSingleDevice:
         self.checkpoint_dir = model_checkpoint_dir(model)
 
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
+        self._checkpoint_format = config.checkpoint_format
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -444,6 +445,7 @@ class LoraFinetuningSingleDevice:
         return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
+            checkpoint_format=self._checkpoint_format,
         )
 
     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
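Taken together, the recipe now reads checkpoint_format from TorchtunePostTrainingConfig, keeps it as self._checkpoint_format, and forwards it to TorchtuneCheckpointer.save_checkpoint. A hedged sketch of the output layout each format implies; the directory, model, and algorithm names are illustrative, and the real path comes from f"{self._model_id}-{self._training_algorithm}-{epoch}".

from pathlib import Path


def expected_checkpoint_files(output_dir, model_id, algorithm, epoch, checkpoint_format):
    """Illustrative only: the kind of files save_checkpoint leaves behind per format."""
    root = Path(output_dir) / f"{model_id}-{algorithm}-{epoch}"
    if checkpoint_format == "meta":
        # single consolidated torch checkpoint plus copied inference files
        return [root / "consolidated.00.pth", root / "params.json", root / "tokenizer.model"]
    if checkpoint_format == "hf":
        # sharded safetensors plus an index; the shard count depends on the weight map
        return [root / "model-00001-of-00001.safetensors", root / "model.safetensors.index.json"]
    raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")


print(expected_checkpoint_files("/tmp/ckpts", "llama3_2_3b", "sft", 0, "hf"))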