From 8132b4e177a5e714453a8d79822189631ae380c4 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Thu, 12 Dec 2024 15:33:59 -0800
Subject: [PATCH] refine

---
 llama_stack/apis/common/training_types.py    |  3 +-
 .../recipes/lora_finetuning_single_device.py | 47 +++++++++++++------
 .../experimental-post-training/run.yaml      | 15 ------
 3 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index 4c01a9e00..a9e3cac7c 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 from datetime import datetime
+from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
@@ -25,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metric: PostTrainingMetric
+    training_metric: Optional[PostTrainingMetric] = None
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 42c3f2169..ec72fedb5 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -23,6 +23,7 @@ from llama_stack.providers.inline.post_training.torchtune.common.checkpointer im
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
@@ -186,12 +187,15 @@ class LoraFinetuningSingleDevice:
             shuffle=self._shuffle,
             batch_size=self._batch_size,
         )
-        _, self._validation_dataloader = await self._setup_data(
-            dataset_id=self.training_config.data_config.validation_dataset_id,
-            tokenizer=self._tokenizer,
-            shuffle=False,
-            batch_size=self._batch_size,
-        )
+
+        if self.training_config.data_config.validation_dataset_id:
+            _, self._validation_dataloader = await self._setup_data(
+                dataset_id=self.training_config.data_config.validation_dataset_id,
+                tokenizer=self._tokenizer,
+                shuffle=False,
+                batch_size=self._batch_size,
+            )
+
         log.info("Dataset and Sampler are initialized.")
 
         # Number of training steps in each epoch depends on the number of batches produced
@@ -461,6 +465,7 @@ class LoraFinetuningSingleDevice:
             self._training_sampler.set_epoch(curr_epoch)
 
             loss_to_log = 0.0
+            pbar = tqdm(total=self._steps_per_epoch)
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
@@ -498,6 +503,12 @@ class LoraFinetuningSingleDevice:
                     self.global_step += 1
 
                     loss_to_log = running_loss.item() / num_tokens
+
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
                     time_per_step = time.perf_counter() - t0
                     log_dict = {
                         "loss": loss_to_log,
@@ -524,28 +535,31 @@ class LoraFinetuningSingleDevice:
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
             checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
-            validation_loss, perplexity = await self.validate()
-            training_metreic = PostTrainingMetric(
-                epoch=curr_epoch,
-                train_loss=loss_to_log,
-                validation_loss=validation_loss,
-                perplexity=perplexity,
-            )
             checkpoint = Checkpoint(
                 identifier=f"{self.model_id}-sft-{curr_epoch}",
                 created_at=datetime.now(),
                 epoch=curr_epoch,
                 post_training_job_id=self.job_uuid,
                 path=checkpoint_path,
-                training_metric=training_metreic,
             )
+            if self.training_config.data_config.validation_dataset_id:
+                validation_loss, perplexity = await self.validation()
+                training_metric = PostTrainingMetric(
+                    epoch=curr_epoch,
+                    train_loss=loss_to_log,
+                    validation_loss=validation_loss,
+                    perplexity=perplexity,
+                )
+                checkpoint.training_metric = training_metric
             checkpoints.append(checkpoint)
 
         return (memory_stats, checkpoints)
 
-    async def validate(self) -> Tuple[float, float]:
+    async def validation(self) -> Tuple[float, float]:
         total_loss = 0.0
         total_tokens = 0
+        log.info("Starting validation...")
+        pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
             if idx == 10:
                 break
@@ -562,6 +576,9 @@ class LoraFinetuningSingleDevice:
             total_loss += loss
             total_tokens += num_tokens
 
+            pbar.update(1)
+            pbar.set_description(f"validation step: {idx}")
+
         mean_loss = total_loss / total_tokens
         perplexity = torch.exp(torch.tensor(mean_loss))
 
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index 5bdcef008..4bdde7aa6 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -49,20 +49,5 @@ datasets:
         type: string
       text:
         type: string
-  - dataset_id: alpaca_eval
-    provider_id: huggingface-0
-    url:
-      uri: https://huggingface.co/datasets/causal-lm/code_alpaca
-    metadata:
-      path: causal-lm/code_alpaca
-      name:
-      split: validation
-    dataset_schema:
-      instruction:
-        type: string
-      input:
-        type: string
-      output:
-        type: string
 scoring_fns: []
 eval_tasks: []
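
Usage note (not part of the patch): a minimal sketch of how a Checkpoint can be built once training_metric is Optional, e.g. when no validation_dataset_id is configured and validation is skipped. The identifier, job id, and path values below are hypothetical.

from datetime import datetime

from llama_stack.apis.common.training_types import Checkpoint

# No validation dataset configured, so no PostTrainingMetric is attached;
# with the Optional field, training_metric simply defaults to None.
checkpoint = Checkpoint(
    identifier="my-model-sft-0",              # hypothetical checkpoint identifier
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="job-1234",          # hypothetical job uuid
    path="/tmp/checkpoints/my-model-sft-0",   # hypothetical checkpoint path
)
assert checkpoint.training_metric is None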