[post training] define llama stack post training dataset format (#717)

## context
In this PR, we defined 2 llama stack dataset formats (instruct, dialog)

- For instruct dataset format, the column schema will be
[chat_completion_input, expected_answer], which is consistent with the
eval data format. This dataset format is the abstract of single turn QA
style post training data
- For dialog dataset format, the column schema will be [dialog], which
is a list of user messages and assistant messages that interleave
together. During training, the whole list will be the model input and
the loss is calculated on assistant messages only. This dataset format
is the abstract of multi turn chat style post training data

## changes
- defined the 2 llama stack dataset formats
- an adapter to convert llama stack dataset format to torchtune dataset
format
- move dataset format validation to post training level instead of
torchtune level since it's not specific to torchtune
- add localfs as datasetio provider


## test 
instruct format
- use https://huggingface.co/datasets/llamastack/evals as dataset and
the training works as expected
<img width="1443" alt="Screenshot 2025-01-09 at 5 15 14 PM"
src="https://github.com/user-attachments/assets/2c37a936-c67a-4726-90e0-23fa0ba7000f"
/>

- use my generated local dataset and the training works as expected

<img width="1617" alt="Screenshot 2025-01-09 at 5 19 11 PM"
src="https://github.com/user-attachments/assets/0bdccbbf-bac2-472a-a365-15213e49bbfa"
/>


dialog format
- use my generated local dataset and the training works as expected
<img width="1588" alt="Screenshot 2025-01-09 at 5 23 16 PM"
src="https://github.com/user-attachments/assets/893915ba-41a3-4d51-948b-e872060ecede"
/>
This commit is contained in:
Botao Chen 2025-01-14 12:48:49 -08:00 committed by GitHub
parent a174938fbd
commit 25c1d9b037
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 182 additions and 75 deletions

View file

@ -18,7 +18,7 @@ from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@ -47,6 +47,9 @@ from llama_stack.apis.post_training import (
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
@ -129,8 +132,10 @@ class LoraFinetuningSingleDevice:
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = training_config.n_epochs
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
self._train_on_input = training_config.data_config.train_on_input
# this is important for debugging purpose
self.max_steps_per_epoch = training_config.max_steps_per_epoch
@ -354,18 +359,17 @@ class LoraFinetuningSingleDevice:
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows
# Curretly only support alpaca instruct dataset
# TODO @SLR722 make the message_transform swappable and support more dataset types
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
await utils.validate_input_dataset_schema(
await validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
dataset_type="alpaca",
dataset_type=self._data_format.value,
)
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,
message_transform=AlpacaToMessages(train_on_input=False),
message_transform=data_transform(train_on_input=self._train_on_input),
model_transform=tokenizer,
dataset_type=self._data_format.value,
)
sampler = DistributedSampler(