From dc3d9d7720c3044937501de788ad2c2a8940fb28 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Tue, 10 Dec 2024 19:39:58 -0800
Subject: [PATCH] init commit

---
 llama_stack/apis/common/training_types.py    |  3 +-
 .../recipes/lora_finetuning_single_device.py | 60 ++++++++++++++++---
 .../experimental-post-training/run.yaml      | 15 +++++
 3 files changed, 68 insertions(+), 10 deletions(-)

diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index b4bd1b0c6..4c01a9e00 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from datetime import datetime
-from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
@@ -26,4 +25,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metric: PostTrainingMetric
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 734b437b4..42c3f2169 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -180,11 +180,18 @@ class LoraFinetuningSingleDevice:
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")
 
-        self._sampler, self._dataloader = await self._setup_data(
+        self._training_sampler, self._training_dataloader = await self._setup_data(
+            dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
             shuffle=self._shuffle,
             batch_size=self._batch_size,
         )
+        _, self._validation_dataloader = await self._setup_data(
+            dataset_id=self.training_config.data_config.validation_dataset_id,
+            tokenizer=self._tokenizer,
+            shuffle=False,
+            batch_size=self._batch_size,
+        )
         log.info("Dataset and Sampler are initialized.")
 
         # Number of training steps in each epoch depends on the number of batches produced
@@ -192,7 +199,7 @@
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
         self._steps_per_epoch = (
-            len(self._dataloader) // self._gradient_accumulation_steps
+            len(self._training_dataloader) // self._gradient_accumulation_steps
         )
         if (
             self.max_steps_per_epoch is not None
@@ -311,15 +318,19 @@
         return optimizer
 
     async def _setup_data(
-        self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
+        self,
+        dataset_id: str,
+        tokenizer: Llama3Tokenizer,
+        shuffle: bool,
+        batch_size: int,
     ) -> Tuple[DistributedSampler, DataLoader]:
-        async def fetch_rows():
+        async def fetch_rows(dataset_id: str):
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.training_config.data_config.dataset_id,
+                dataset_id=dataset_id,
                 rows_in_page=-1,
             )
 
-        all_rows = await fetch_rows()
+        all_rows = await fetch_rows(dataset_id)
         rows = all_rows.rows
 
         # Curretly only support alpaca instruct dataset
@@ -447,9 +458,10 @@
             metric_logger = DiskLogger(
                 log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
            )
-            self._sampler.set_epoch(curr_epoch)
+            self._training_sampler.set_epoch(curr_epoch)
+            loss_to_log = 0.0
 
-            for idx, batch in enumerate(self._dataloader):
+            for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
                     and (idx // self._gradient_accumulation_steps)
@@ -512,13 +524,45 @@
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
             checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
+            validation_loss, perplexity = await self.validate()
+            training_metric = PostTrainingMetric(
+                epoch=curr_epoch,
+                train_loss=loss_to_log,
+                validation_loss=validation_loss,
+                perplexity=perplexity,
+            )
             checkpoint = Checkpoint(
                 identifier=f"{self.model_id}-sft-{curr_epoch}",
                 created_at=datetime.now(),
                 epoch=curr_epoch,
                 post_training_job_id=self.job_uuid,
                 path=checkpoint_path,
+                training_metric=training_metric,
             )
             checkpoints.append(checkpoint)
 
         return (memory_stats, checkpoints)
+
+    async def validate(self) -> Tuple[float, float]:
+        total_loss = 0.0
+        total_tokens = 0
+        for idx, batch in enumerate(self._validation_dataloader):
+            if idx == 10:
+                break
+            torchtune_utils.batch_to_device(batch, self._device)
+
+            # Calculate the number of unmasked tokens in the current batch
+            # and increment the total number of tokens seen in the step
+            num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
+
+            # Loss is normalized by default so we multiply by the number of tokens
+            # This way we can normalize by the total number of tokens if we're accumulating gradients
+            loss = await self._loss_step(batch) * num_tokens
+
+            total_loss += loss
+            total_tokens += num_tokens
+
+        mean_loss = (total_loss / total_tokens).item()
+        perplexity = torch.exp(torch.tensor(mean_loss))
+
+        return mean_loss, perplexity.item()
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index 4bdde7aa6..5bdcef008 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -49,5 +49,20 @@ datasets:
         type: string
       text:
         type: string
+- dataset_id: alpaca_eval
+  provider_id: huggingface-0
+  url:
+    uri: https://huggingface.co/datasets/causal-lm/code_alpaca
+  metadata:
+    path: causal-lm/code_alpaca
+    name:
+    split: validation
+  dataset_schema:
+    instruction:
+      type: string
+    input:
+      type: string
+    output:
+      type: string
 scoring_fns: []
 eval_tasks: []