This commit is contained in:
Botao Chen 2025-01-09 16:57:48 -08:00
parent 49e1a25343
commit 929e7c0e69
6 changed files with 9 additions and 25 deletions

View file

@@ -23,7 +23,7 @@ EXPECTED_DATASET_SCHEMA = {
ColumnName.expected_answer.value: StringType(),
}
],
-"chat": [
+"dialog": [
{
ColumnName.dialog.value: StringType(),
}

View file

@@ -49,7 +49,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
DATA_FORMATS: Dict[str, Transform] = {
"instruct": InputOutputToMessages,
-"chat": ShareGPTToMessages,
+"dialog": ShareGPTToMessages,
}

View file

@@ -30,6 +30,7 @@ def llama_stack_instruct_to_torchtune_instruct(
input_message = input_messages[0]
assert "content" in input_message, "content not found in input message"
input = input_message["content"]
output = sample[ColumnName.expected_answer.value]
return {

View file

@@ -135,7 +135,7 @@ class LoraFinetuningSingleDevice:
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
-self._training_on_input = training_config.data_config.training_on_input
+self._train_on_input = training_config.data_config.train_on_input
# this is important for debugging purpose
self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -367,7 +367,7 @@ class LoraFinetuningSingleDevice:
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,
-message_transform=data_transform(train_on_input=self._training_on_input),
+message_transform=data_transform(train_on_input=self._train_on_input),
model_transform=tokenizer,
dataset_type=self._data_format.value,
)