Botao Chen 2024-12-12 15:33:59 -08:00
parent dc3d9d7720
commit 8132b4e177
3 changed files with 34 additions and 31 deletions

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@@ -25,4 +26,4 @@ class Checkpoint(BaseModel):
epoch: int
post_training_job_id: str
path: str
training_metric: PostTrainingMetric
training_metric: Optional[PostTrainingMetric] = None
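
As a minimal sketch of what the now-optional field allows (assuming Checkpoint and PostTrainingMetric are importable from llama_stack.apis.post_training, as the recipe's wildcard import below suggests; identifiers, paths, and metric values here are made up):

from datetime import datetime

from llama_stack.apis.post_training import Checkpoint, PostTrainingMetric

# A checkpoint can now be built without metrics; training_metric defaults to None.
checkpoint = Checkpoint(
    identifier="my-model-sft-0",        # hypothetical identifier
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="job-1234",    # hypothetical job id
    path="/tmp/checkpoints/epoch_0",    # hypothetical path
)
assert checkpoint.training_metric is None

# If validation does run, the metric can be attached after construction.
checkpoint.training_metric = PostTrainingMetric(
    epoch=0,
    train_loss=1.23,
    validation_loss=1.31,
    perplexity=3.71,
)

Defaulting the field to None lets a job emit checkpoints even when no validation dataset is configured, which is what the recipe change below relies on.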

View file

@@ -23,6 +23,7 @@ from llama_stack.providers.inline.post_training.torchtune.common.checkpointer im
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir
@@ -186,12 +187,15 @@ class LoraFinetuningSingleDevice:
shuffle=self._shuffle,
batch_size=self._batch_size,
)
_, self._validation_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.validation_dataset_id,
tokenizer=self._tokenizer,
shuffle=False,
batch_size=self._batch_size,
)
if self.training_config.data_config.validation_dataset_id:
_, self._validation_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.validation_dataset_id,
tokenizer=self._tokenizer,
shuffle=False,
batch_size=self._batch_size,
)
log.info("Dataset and Sampler are initialized.")
# Number of training steps in each epoch depends on the number of batches produced
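
Since self._validation_dataloader is only assigned when a validation dataset id is configured, every later use has to be guarded the same way. A self-contained sketch of that guard pattern (the loader is stubbed with a plain list; all names are illustrative only):

from typing import List, Optional

validation_dataset_id: Optional[str] = None  # e.g. data_config.validation_dataset_id
validation_dataloader: Optional[List[str]] = None

if validation_dataset_id:
    # stand-in for the _setup_data(...) call above
    validation_dataloader = ["batch_0", "batch_1"]

if validation_dataloader is not None:
    for batch in validation_dataloader:
        pass  # run a validation step per batch
else:
    print("No validation dataset configured; skipping validation.")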
@@ -461,6 +465,7 @@ class LoraFinetuningSingleDevice:
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
pbar = tqdm(total=self._steps_per_epoch)
for idx, batch in enumerate(self._training_dataloader):
if (
self.max_steps_per_epoch is not None
@@ -498,6 +503,12 @@
self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
@@ -524,28 +535,31 @@
self.epochs_run += 1
log.info("Starting checkpoint save...")
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
validation_loss, perplexity = await self.validate()
training_metric = PostTrainingMetric(
epoch=curr_epoch,
train_loss=loss_to_log,
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
training_metric=training_metric,
)
if self.training_config.data_config.validation_dataset_id:
validation_loss, perplexity = await self.validation()
training_metric = PostTrainingMetric(
epoch=curr_epoch,
train_loss=loss_to_log,
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint.training_metric = training_metric
checkpoints.append(checkpoint)
return (memory_stats, checkpoints)
async def validate(self) -> Tuple[float, float]:
async def validation(self) -> Tuple[float, float]:
total_loss = 0.0
total_tokens = 0
log.info("Starting validation...")
pbar = tqdm(total=len(self._validation_dataloader))
for idx, batch in enumerate(self._validation_dataloader):
if idx == 10:
break
@@ -562,6 +576,9 @@ class LoraFinetuningSingleDevice:
total_loss += loss
total_tokens += num_tokens
pbar.update(1)
pbar.set_description(f"validation step: {idx}")
mean_loss = total_loss / total_tokens
perplexity = torch.exp(torch.tensor(mean_loss))
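
The metric returned above is perplexity, i.e. the exponential of the token-averaged loss accumulated over the validation batches. A self-contained sketch with made-up numbers:

import torch

total_loss = 412.5   # summed loss over the validation batches
total_tokens = 300   # tokens processed in those batches
mean_loss = total_loss / total_tokens
perplexity = torch.exp(torch.tensor(mean_loss))
print(f"mean loss {mean_loss:.3f} -> perplexity {perplexity.item():.2f}")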

View file

@@ -49,20 +49,5 @@ datasets:
type: string
text:
type: string
- dataset_id: alpaca_eval
provider_id: huggingface-0
url:
uri: https://huggingface.co/datasets/causal-lm/code_alpaca
metadata:
path: causal-lm/code_alpaca
name:
split: validation
dataset_schema:
instruction:
type: string
input:
type: string
output:
type: string
scoring_fns: []
eval_tasks: []