diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index ae78c227f..1c9febc73 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -17,19 +17,11 @@ from torchtune.models import convert_weights
 from torchtune.training.checkpointing._utils import (
     ADAPTER_CONFIG_FNAME,
     ADAPTER_MODEL_FNAME,
-    check_outdir_not_in_ckptdir,
     copy_files,
-    get_adapter_checkpoint_path,
-    get_model_checkpoint_path,
-    get_recipe_checkpoint_path,
     ModelType,
-    RECIPE_STATE_DIRNAME,
     REPO_ID_FNAME,
     safe_torch_load,
-    SAFETENSOR_INDEX_FNAME,
-    SHARD_FNAME,
     SUFFIXES_TO_NOT_COPY,
-    TORCH_INDEX_FNAME,
 )
 from torchtune.utils._logging import get_logger
 
@@ -176,8 +168,6 @@ class TorchtuneCheckpointer:
                    raise ValueError(
                        "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
                    )
-
-            print("model_file_path", str(model_file_path))
 
         elif checkpoint_format == "hf":
             # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
@@ -238,7 +228,7 @@ class TorchtuneCheckpointer:
                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
                    f"saved to {output_path}"
                )
-            elif adapter_only:
+            else:
                raise ValueError(
                    "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
                )
@@ -269,18 +259,6 @@ class TorchtuneCheckpointer:
                model_file_path,
                ignore_suffixes=SUFFIXES_TO_NOT_COPY,
            )
-            logger.info("Saving final epoch checkpoint.")
-            if adapter_only:
-                logger.info(
-                    "Please note that you have set adapter_only=True, so only adapter weights will be saved."
-                    "You need to merge the adapter weights into your base model for further use. "
-                    f"See {self.__class__.__name__}.save_checkpoint for more details."
-                )
-            else:
-                logger.info(
-                    "The full model checkpoint, including all weights and configurations, has been saved successfully."
-                    "You can now use this checkpoint for further training or inference."
-                )
         else:
             raise ValueError(f"Unsupported checkpoint format: {format}")
 
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index e64676a5e..ac61bc6cc 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -15,6 +15,24 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
+from torch import nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import modules, training, utils as torchtune_utils
+from torchtune.data import padded_collate_sft
+
+from torchtune.modules.loss import CEWithChunkedOutputLoss
+from torchtune.modules.peft import (
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
@@ -42,24 +60,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from torch import nn
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import padded_collate_sft
-
-from torchtune.modules.loss import CEWithChunkedOutputLoss
-from torchtune.modules.peft import (
-    get_adapter_params,
-    get_adapter_state_dict,
-    get_lora_module_names,
-    get_merged_lora_ckpt,
-    set_trainable_params,
-    validate_missing_and_unexpected_for_lora,
-)
-from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
 
 log = logging.getLogger(__name__)
 
@@ -490,7 +490,7 @@ class LoraFinetuningSingleDevice:
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
             metric_logger = DiskLogger(
-                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
             )
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0