This commit is contained in:
Botao Chen 2025-01-22 16:54:16 -08:00
parent 09e9445a11
commit a57f46e363
2 changed files with 20 additions and 42 deletions

View file

@ -17,19 +17,11 @@ from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import (
ADAPTER_CONFIG_FNAME,
ADAPTER_MODEL_FNAME,
check_outdir_not_in_ckptdir,
copy_files,
get_adapter_checkpoint_path,
get_model_checkpoint_path,
get_recipe_checkpoint_path,
ModelType,
RECIPE_STATE_DIRNAME,
REPO_ID_FNAME,
safe_torch_load,
SAFETENSOR_INDEX_FNAME,
SHARD_FNAME,
SUFFIXES_TO_NOT_COPY,
TORCH_INDEX_FNAME,
)
from torchtune.utils._logging import get_logger
@ -176,8 +168,6 @@ class TorchtuneCheckpointer:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
print("model_file_path", str(model_file_path))
elif checkpoint_format == "hf":
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
@ -238,7 +228,7 @@ class TorchtuneCheckpointer:
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
f"saved to {output_path}"
)
elif adapter_only:
else:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
@ -269,18 +259,6 @@ class TorchtuneCheckpointer:
model_file_path,
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
)
logger.info("Saving final epoch checkpoint.")
if adapter_only:
logger.info(
"Please note that you have set adapter_only=True, so only adapter weights will be saved."
"You need to merge the adapter weights into your base model for further use. "
f"See {self.__class__.__name__}.save_checkpoint for more details."
)
else:
logger.info(
"The full model checkpoint, including all weights and configurations, has been saved successfully."
"You can now use this checkpoint for further training or inference."
)
else:
raise ValueError(f"Unsupported checkpoint format: {format}")

View file

@ -15,6 +15,24 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
@ -42,24 +60,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
log = logging.getLogger(__name__)
@ -490,7 +490,7 @@ class LoraFinetuningSingleDevice:
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
)
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0