refine api

Botao Chen 2024-12-03 20:01:27 -08:00
parent 5838b7211d
commit 41cf2bb0a7
3 changed files with 74 additions and 40 deletions

View file

@@ -24,8 +24,6 @@ class MetaReferencePostTrainingImpl:
self,
job_uuid: str,
model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
training_config: TrainingConfig,
@@ -37,8 +35,6 @@ class MetaReferencePostTrainingImpl:
request = PostTrainingSFTRequest(
job_uuid=job_uuid,
model=model,
dataset_id=dataset_id,
validation_dataset_id=validation_dataset_id,
algorithm=algorithm,
algorithm_config=algorithm_config,
training_config=training_config,
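For context, a minimal caller-side sketch of the refined API, assuming the dataset now rides inside training_config.data_config rather than being passed as top-level arguments. The DataConfig and EfficiencyConfig class names, the supervised_fine_tune method name, and all values below are inferred from later hunks in this commit and are illustrative, not confirmed by this file:

# Hypothetical caller-side sketch (not part of this commit).
from llama_stack.apis.post_training import (  # DataConfig/EfficiencyConfig names assumed
    DataConfig,
    EfficiencyConfig,
    FinetuningAlgorithm,
    TrainingConfig,
)

training_config = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=100,
    data_config=DataConfig(
        dataset_id="alpaca",   # previously a top-level dataset_id argument
        batch_size=2,
        shuffle=True,
    ),
    efficiency_config=EfficiencyConfig(  # optional; defaults apply when omitted
        enable_activation_checkpointing=True,
    ),
)

# Inside an async caller, with impl a MetaReferencePostTrainingImpl instance
# and lora_config a LoraFinetuningConfig (fields omitted for brevity):
await impl.supervised_fine_tune(  # method name assumed
    job_uuid="job-1234",
    model="Llama3.2-3B-Instruct",
    algorithm=FinetuningAlgorithm.lora,  # illustrative enum value
    algorithm_config=lora_config,
    training_config=training_config,
)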

View file

@@ -6,6 +6,7 @@
import logging
import os
import time
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -15,6 +16,7 @@ from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa
from llama_stack.apis.post_training import PostTrainingSFTRequest
from llama_stack.distribution.utils.model_utils import model_local_dir
@@ -29,7 +31,7 @@ from llama_stack.providers.inline.post_training.meta_reference.datasets.sft impo
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune.data import InputOutputToMessages, padded_collate_sft
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@@ -103,8 +105,8 @@ class LoraFinetuningSingleDevice:
self.seed = training.set_seed(seed=config.torch_seed or 42)
self.epochs_run = 0
self.total_epochs = request.training_config.n_epochs
self._shuffle = request.training_config.shuffle
self._batch_size = request.training_config.batch_size
self._shuffle = request.training_config.data_config.shuffle
self._batch_size = request.training_config.data_config.batch_size
# this is important for debugging purposes
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
@@ -116,9 +118,15 @@ class LoraFinetuningSingleDevice:
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
request.training_config.enable_activation_checkpointing
(request.training_config.efficiency_config.enable_activation_checkpointing)
if request.training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(request.training_config.efficiency_config.enable_activation_offloading)
if request.training_config.efficiency_config
else False
)
self._enable_activation_offloading = False
self.datasetio_api = datasetio_api
@@ -143,6 +151,9 @@ class LoraFinetuningSingleDevice:
return checkpoint_dict
async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
# temporarily log to local disk, will figure out how to interop with telemetry
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
@@ -212,7 +223,7 @@ class LoraFinetuningSingleDevice:
self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
self._use_dora = self.request.algorithm_config.use_dora
self._use_dora = self.request.algorithm_config.use_dora or False
with training.set_default_dtype(self._dtype), self._device:
model_type = utils.get_model_type(self.model_id)
@@ -272,6 +283,10 @@ class LoraFinetuningSingleDevice:
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
return model
async def _setup_tokenizer(
@@ -296,7 +311,7 @@ class LoraFinetuningSingleDevice:
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows():
return await self.datasetio_api.get_rows_paginated(
dataset_id=self.request.dataset_id,
dataset_id=self.request.training_config.data_config.dataset_id,
rows_in_page=-1,
)
@@ -306,7 +321,9 @@ class LoraFinetuningSingleDevice:
# Currently only supports instruct datasets
# TODO @markchen1015 make the message_transform swappable and support more dataset types
ds = SFTDataset(
rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
rows,
message_transform=AlpacaToMessages(train_on_input=False),
model_transform=tokenizer,
)
sampler = DistributedSampler(
@@ -412,6 +429,7 @@ class LoraFinetuningSingleDevice:
"""
# Initialize token count and running loss (for grad accumulation)
# t0 = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
@@ -447,7 +465,7 @@ class LoraFinetuningSingleDevice:
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
training.scale_grads(self._model, 1 / num_tokens)
torch.nn.utils.clip_grad_norm_(
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
@@ -457,9 +475,25 @@ class LoraFinetuningSingleDevice:
# Update the number of steps when the weights are updated
self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
log_dict.update(training.get_memory_stats(device=self._device))
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
)
# Reset running stats for the next step
running_loss = 0
num_tokens = 0
t0 = time.perf_counter()
self.epochs_run += 1
log.info("Starting checkpoint save...")
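For reference, a self-contained sketch of the per-step metric-logging pattern added above, using torchtune's DiskLogger; the log directory, step count, and metric values are illustrative stand-ins for the real training loop:

import time

from torchtune.training.metric_logging import DiskLogger

metric_logger = DiskLogger(log_dir="/tmp/post_training_logs")  # illustrative path
global_step = 0
t0 = time.perf_counter()

for _ in range(3):  # stand-in for the inner training loop
    running_loss, num_tokens = 1.25, 512  # placeholder values
    global_step += 1
    metric_logger.log_dict(
        {
            "loss": running_loss / num_tokens,
            "tokens_per_second_per_gpu": num_tokens / (time.perf_counter() - t0),
        },
        step=global_step,
    )
    t0 = time.perf_counter()  # reset the timer for the next step

metric_logger.close()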