init commit

This commit is contained in:
Botao Chen 2024-12-10 19:39:58 -08:00
parent e5993c565e
commit dc3d9d7720
3 changed files with 68 additions and 10 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
from datetime import datetime from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
@ -26,4 +25,4 @@ class Checkpoint(BaseModel):
epoch: int epoch: int
post_training_job_id: str post_training_job_id: str
path: str path: str
training_metrics: Optional[PostTrainingMetric] = None training_metric: PostTrainingMetric

View file

@ -180,11 +180,18 @@ class LoraFinetuningSingleDevice:
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
log.info("Loss is initialized.") log.info("Loss is initialized.")
self._sampler, self._dataloader = await self._setup_data( self._training_sampler, self._training_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.dataset_id,
tokenizer=self._tokenizer, tokenizer=self._tokenizer,
shuffle=self._shuffle, shuffle=self._shuffle,
batch_size=self._batch_size, batch_size=self._batch_size,
) )
_, self._validation_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.validation_dataset_id,
tokenizer=self._tokenizer,
shuffle=False,
batch_size=self._batch_size,
)
log.info("Dataset and Sampler are initialized.") log.info("Dataset and Sampler are initialized.")
# Number of training steps in each epoch depends on the number of batches produced # Number of training steps in each epoch depends on the number of batches produced
@ -192,7 +199,7 @@ class LoraFinetuningSingleDevice:
# for logging and tracking training state. This should be computed after the dataloader # for logging and tracking training state. This should be computed after the dataloader
# has been setup # has been setup
self._steps_per_epoch = ( self._steps_per_epoch = (
len(self._dataloader) // self._gradient_accumulation_steps len(self._training_dataloader) // self._gradient_accumulation_steps
) )
if ( if (
self.max_steps_per_epoch is not None self.max_steps_per_epoch is not None
@ -311,15 +318,19 @@ class LoraFinetuningSingleDevice:
return optimizer return optimizer
async def _setup_data( async def _setup_data(
self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int self,
dataset_id: str,
tokenizer: Llama3Tokenizer,
shuffle: bool,
batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]: ) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(): async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated( return await self.datasetio_api.get_rows_paginated(
dataset_id=self.training_config.data_config.dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )
all_rows = await fetch_rows() all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows rows = all_rows.rows
# Currently only support alpaca instruct dataset # Currently only support alpaca instruct dataset
@ -447,9 +458,10 @@ class LoraFinetuningSingleDevice:
metric_logger = DiskLogger( metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}" log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
) )
self._sampler.set_epoch(curr_epoch) self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
for idx, batch in enumerate(self._dataloader): for idx, batch in enumerate(self._training_dataloader):
if ( if (
self.max_steps_per_epoch is not None self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps) and (idx // self._gradient_accumulation_steps)
@ -512,13 +524,45 @@ class LoraFinetuningSingleDevice:
self.epochs_run += 1 self.epochs_run += 1
log.info("Starting checkpoint save...") log.info("Starting checkpoint save...")
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch) checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
validation_loss, perplexity = await self.validate()
training_metreic = PostTrainingMetric(
epoch=curr_epoch,
train_loss=loss_to_log,
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint = Checkpoint( checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}", identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(), created_at=datetime.now(),
epoch=curr_epoch, epoch=curr_epoch,
post_training_job_id=self.job_uuid, post_training_job_id=self.job_uuid,
path=checkpoint_path, path=checkpoint_path,
training_metric=training_metreic,
) )
checkpoints.append(checkpoint) checkpoints.append(checkpoint)
return (memory_stats, checkpoints) return (memory_stats, checkpoints)
async def validate(self) -> Tuple[float, float]:
total_loss = 0.0
total_tokens = 0
for idx, batch in enumerate(self._validation_dataloader):
if idx == 10:
break
torchtune_utils.batch_to_device(batch, self._device)
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
loss = await self._loss_step(batch) * num_tokens
total_loss += loss
total_tokens += num_tokens
mean_loss = total_loss / total_tokens
perplexity = torch.exp(torch.tensor(mean_loss))
return mean_loss, perplexity.item()

View file

@ -49,5 +49,20 @@ datasets:
type: string type: string
text: text:
type: string type: string
- dataset_id: alpaca_eval
provider_id: huggingface-0
url:
uri: https://huggingface.co/datasets/causal-lm/code_alpaca
metadata:
path: causal-lm/code_alpaca
name:
split: validation
dataset_schema:
instruction:
type: string
input:
type: string
output:
type: string
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []