forked from phoenix-oss/llama-stack-mirror
[3/n][torchtune integration] add validation logic (#600)
## What does this PR do?

- Add validation logic to the SFT recipe (validation loss and perplexity).
- Add a progress bar to both training and validation to better track progress on the server side (eval has similar logic).

## Test Plan

The validation metrics show up in the checkpoint's `training_metrics` field:

<img width="799" alt="Screenshot 2024-12-12 at 3 21 52 PM" src="https://github.com/user-attachments/assets/36330ffe-0555-4b2d-93f0-9487dfdf7b4e" />

The progress bar shows up as expected:

<img width="476" alt="Screenshot 2024-12-12 at 3 38 11 PM" src="https://github.com/user-attachments/assets/77306fa2-cb9c-460f-8efc-b41bbe424a7d" />
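For reference, the perplexity reported in `training_metrics` is just the exponential of the mean per-token validation loss that the new `validation()` method accumulates. A minimal sketch of that relationship (the numbers are illustrative, not taken from the test run above):

```python
import torch

# Sum of (per-batch loss * unmasked token count) and the total unmasked token
# count, as accumulated over validation batches in the new validation() method.
# Values here are made up purely for illustration.
total_loss = torch.tensor(6.5917)
total_tokens = torch.tensor(3)

mean_loss = total_loss / total_tokens  # token-level validation loss
perplexity = torch.exp(mean_loss)      # exp(2.1972) ≈ 9.0

print(f"validation_loss={mean_loss.item():.4f}, perplexity={perplexity.item():.4f}")
```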
This commit is contained in:
parent c294a01c4b
commit 20383bfea5

1 changed file with 68 additions and 9 deletions
@@ -23,6 +23,7 @@ from llama_stack.providers.inline.post_training.torchtune.common.checkpointer im
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
@@ -185,11 +186,21 @@ class LoraFinetuningSingleDevice:
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")
 
-        self._sampler, self._dataloader = await self._setup_data(
+        self._training_sampler, self._training_dataloader = await self._setup_data(
+            dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
             shuffle=self._shuffle,
             batch_size=self._batch_size,
         )
+
+        if self.training_config.data_config.validation_dataset_id:
+            _, self._validation_dataloader = await self._setup_data(
+                dataset_id=self.training_config.data_config.validation_dataset_id,
+                tokenizer=self._tokenizer,
+                shuffle=False,
+                batch_size=self._batch_size,
+            )
+
         log.info("Dataset and Sampler are initialized.")
 
         # Number of training steps in each epoch depends on the number of batches produced
@@ -197,7 +208,7 @@ class LoraFinetuningSingleDevice:
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
         self._steps_per_epoch = (
-            len(self._dataloader) // self._gradient_accumulation_steps
+            len(self._training_dataloader) // self._gradient_accumulation_steps
         )
         if (
             self.max_steps_per_epoch is not None
@@ -316,17 +327,19 @@ class LoraFinetuningSingleDevice:
         return optimizer
 
     async def _setup_data(
-        self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
+        self,
+        dataset_id: str,
+        tokenizer: Llama3Tokenizer,
+        shuffle: bool,
+        batch_size: int,
     ) -> Tuple[DistributedSampler, DataLoader]:
-        dataset_id = self.training_config.data_config.dataset_id
-
-        async def fetch_rows():
+        async def fetch_rows(dataset_id: str):
             return await self.datasetio_api.get_rows_paginated(
                 dataset_id=dataset_id,
                 rows_in_page=-1,
             )
 
-        all_rows = await fetch_rows()
+        all_rows = await fetch_rows(dataset_id)
         rows = all_rows.rows
 
         # Curretly only support alpaca instruct dataset
@@ -460,9 +473,11 @@ class LoraFinetuningSingleDevice:
             metric_logger = DiskLogger(
                 log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
             )
-            self._sampler.set_epoch(curr_epoch)
+            self._training_sampler.set_epoch(curr_epoch)
+            loss_to_log = 0.0
 
-            for idx, batch in enumerate(self._dataloader):
+            pbar = tqdm(total=self._steps_per_epoch)
+            for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
                     and (idx // self._gradient_accumulation_steps)
@@ -499,6 +514,12 @@ class LoraFinetuningSingleDevice:
                     self.global_step += 1
 
                     loss_to_log = running_loss.item() / num_tokens
+
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
                     time_per_step = time.perf_counter() - t0
                     log_dict = {
                         "loss": loss_to_log,
@@ -532,6 +553,44 @@ class LoraFinetuningSingleDevice:
                 post_training_job_id=self.job_uuid,
                 path=checkpoint_path,
             )
+            if self.training_config.data_config.validation_dataset_id:
+                validation_loss, perplexity = await self.validation()
+                training_metrics = PostTrainingMetric(
+                    epoch=curr_epoch,
+                    train_loss=loss_to_log,
+                    validation_loss=validation_loss,
+                    perplexity=perplexity,
+                )
+                checkpoint.training_metrics = training_metrics
             checkpoints.append(checkpoint)
 
         return (memory_stats, checkpoints)
+
+    async def validation(self) -> Tuple[float, float]:
+        total_loss = 0.0
+        total_tokens = 0
+        log.info("Starting validation...")
+        pbar = tqdm(total=len(self._validation_dataloader))
+        for idx, batch in enumerate(self._validation_dataloader):
+            if idx == 10:
+                break
+            torchtune_utils.batch_to_device(batch, self._device)
+
+            # Calculate the number of unmasked tokens in the current batch
+            # and increment the total number of tokens seen in the step
+            num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
+
+            # Loss is normalized by default so we multiply by the number of tokens
+            # This way we can normalize by the total number of tokens if we're accumulating gradients
+            loss = await self._loss_step(batch) * num_tokens
+
+            total_loss += loss
+            total_tokens += num_tokens
+
+            pbar.update(1)
+            pbar.set_description(f"validation step: {idx}")
+
+        mean_loss = total_loss / total_tokens
+        perplexity = torch.exp(torch.tensor(mean_loss))
+
+        return mean_loss, perplexity.item()
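For readers of this change, a hedged sketch of how the metrics added here surface on the checkpoints returned by `train()`. The helper name, and the assumption that `training_metrics` is left as `None` when no `validation_dataset_id` is configured, are illustrative and not part of this PR:

```python
# Hypothetical helper; assumes `recipe` is an already-initialized
# LoraFinetuningSingleDevice and that checkpoint.training_metrics is None
# unless a validation dataset was configured.
async def report_training_metrics(recipe) -> None:
    memory_stats, checkpoints = await recipe.train()
    for checkpoint in checkpoints:
        metrics = checkpoint.training_metrics
        if metrics is None:
            continue
        print(
            f"epoch={metrics.epoch} "
            f"train_loss={metrics.train_loss:.4f} "
            f"validation_loss={metrics.validation_loss:.4f} "
            f"perplexity={metrics.perplexity:.4f}"
        )
```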