diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 0714046bf..7f1547657 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -23,6 +23,7 @@ from llama_stack.providers.inline.post_training.torchtune.common.checkpointer im from torch import nn from torchtune import utils as torchtune_utils from torchtune.training.metric_logging import DiskLogger +from tqdm import tqdm from llama_stack.apis.post_training import * # noqa from llama_stack.distribution.utils.model_utils import model_local_dir @@ -185,11 +186,21 @@ class LoraFinetuningSingleDevice: self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) log.info("Loss is initialized.") - self._sampler, self._dataloader = await self._setup_data( + self._training_sampler, self._training_dataloader = await self._setup_data( + dataset_id=self.training_config.data_config.dataset_id, tokenizer=self._tokenizer, shuffle=self._shuffle, batch_size=self._batch_size, ) + + if self.training_config.data_config.validation_dataset_id: + _, self._validation_dataloader = await self._setup_data( + dataset_id=self.training_config.data_config.validation_dataset_id, + tokenizer=self._tokenizer, + shuffle=False, + batch_size=self._batch_size, + ) + log.info("Dataset and Sampler are initialized.") # Number of training steps in each epoch depends on the number of batches produced @@ -197,7 +208,7 @@ class LoraFinetuningSingleDevice: # for logging and tracking training state. 
This should be computed after the dataloader # has been setup self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps + len(self._training_dataloader) // self._gradient_accumulation_steps ) if ( self.max_steps_per_epoch is not None @@ -316,17 +327,19 @@ class LoraFinetuningSingleDevice: return optimizer async def _setup_data( - self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int + self, + dataset_id: str, + tokenizer: Llama3Tokenizer, + shuffle: bool, + batch_size: int, ) -> Tuple[DistributedSampler, DataLoader]: - dataset_id = self.training_config.data_config.dataset_id - - async def fetch_rows(): + async def fetch_rows(dataset_id: str): return await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, ) - all_rows = await fetch_rows() + all_rows = await fetch_rows(dataset_id) rows = all_rows.rows # Curretly only support alpaca instruct dataset @@ -460,9 +473,11 @@ class LoraFinetuningSingleDevice: metric_logger = DiskLogger( log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}" ) - self._sampler.set_epoch(curr_epoch) + self._training_sampler.set_epoch(curr_epoch) + loss_to_log = 0.0 - for idx, batch in enumerate(self._dataloader): + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._training_dataloader): if ( self.max_steps_per_epoch is not None and (idx // self._gradient_accumulation_steps) @@ -499,6 +514,12 @@ class LoraFinetuningSingleDevice: self.global_step += 1 loss_to_log = running_loss.item() / num_tokens + + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + time_per_step = time.perf_counter() - t0 log_dict = { "loss": loss_to_log, @@ -532,6 +553,44 @@ class LoraFinetuningSingleDevice: post_training_job_id=self.job_uuid, path=checkpoint_path, ) + if self.training_config.data_config.validation_dataset_id: + validation_loss, perplexity = await self.validation() + training_metrics = 
async def validation(self, max_validation_steps: int = 10) -> Tuple[float, float]:
    """Evaluate the model on the validation dataloader.

    Iterates over ``self._validation_dataloader`` and accumulates a
    token-weighted loss, then derives perplexity as ``exp(mean_loss)``.

    Args:
        max_validation_steps: Upper bound on the number of batches that are
            evaluated. Defaults to 10, the limit that was previously
            hard-coded in the loop.

    Returns:
        A ``(mean_loss, perplexity)`` tuple of plain Python floats. Both
        values are NaN when no unmasked label token was seen (empty
        dataloader, or every label equal to the loss' ``ignore_index``).
    """
    log.info("Starting validation...")

    # The bar should reflect the batches actually evaluated, not the
    # full dataloader length, since the loop stops early.
    num_batches = min(len(self._validation_dataloader), max_validation_steps)

    # Accumulate as Python numbers so no tensors (or autograd graphs) are
    # kept alive across batches.
    total_loss = 0.0
    total_tokens = 0

    # Evaluate in eval mode (e.g. so dropout layers are disabled) and restore
    # whatever mode the model was in, even if a batch raises.
    was_training = self._model.training
    self._model.eval()
    try:
        with tqdm(total=num_batches) as pbar:
            for idx, batch in enumerate(self._validation_dataloader):
                if idx >= max_validation_steps:
                    break
                torchtune_utils.batch_to_device(batch, self._device)

                # Number of unmasked tokens in the current batch.
                num_tokens = int(
                    (batch["labels"] != self._loss_fn.ignore_index).sum()
                )

                # A fully masked batch contributes no tokens; its normalized
                # loss may be NaN and would poison the running total.
                if num_tokens > 0:
                    # No gradients are needed for validation.
                    # NOTE(review): grad mode is thread-local; this assumes
                    # _loss_step does not yield to other tasks - confirm.
                    with torch.no_grad():
                        loss = await self._loss_step(batch)

                    # Loss is normalized by default, so multiply by the
                    # number of tokens to be able to normalize by the total
                    # token count at the end.
                    total_loss += float(loss) * num_tokens
                    total_tokens += num_tokens

                pbar.update(1)
                pbar.set_description(f"validation step: {idx}")
    finally:
        self._model.train(was_training)

    if total_tokens == 0:
        log.warning(
            "Validation saw no unmasked tokens; reporting NaN loss and perplexity."
        )
        return float("nan"), float("nan")

    mean_loss = total_loss / total_tokens
    # torch.exp saturates to inf on very large losses instead of raising.
    perplexity = torch.exp(torch.tensor(mean_loss)).item()

    return mean_loss, perplexity