This commit is contained in:
Botao Chen 2025-01-09 16:57:48 -08:00
parent 49e1a25343
commit 929e7c0e69
6 changed files with 9 additions and 25 deletions

View file

@@ -30,7 +30,7 @@ class OptimizerType(Enum):
@json_schema_type
class DatasetFormat(Enum):
instruct = "instruct"
- chat = "chat"
+ chat = "dialog"
@json_schema_type

View file

@@ -23,7 +23,7 @@ EXPECTED_DATASET_SCHEMA = {
ColumnName.expected_answer.value: StringType(),
}
],
- "chat": [
+ "dialog": [
{
ColumnName.dialog.value: StringType(),
}

View file

@@ -49,7 +49,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
DATA_FORMATS: Dict[str, Transform] = {
"instruct": InputOutputToMessages,
- "chat": ShareGPTToMessages,
+ "dialog": ShareGPTToMessages,
}

View file

@@ -30,6 +30,7 @@ def llama_stack_instruct_to_torchtune_instruct(
input_message = input_messages[0]
assert "content" in input_message, "content not found in input message"
input = input_message["content"]
output = sample[ColumnName.expected_answer.value]
return {

View file

@@ -135,7 +135,7 @@ class LoraFinetuningSingleDevice:
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
- self._training_on_input = training_config.data_config.training_on_input
+ self._train_on_input = training_config.data_config.train_on_input
# this is important for debugging purpose
self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -367,7 +367,7 @@ class LoraFinetuningSingleDevice:
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,
- message_transform=data_transform(train_on_input=self._training_on_input),
+ message_transform=data_transform(train_on_input=self._train_on_input),
model_transform=tokenizer,
dataset_type=self._data_format.value,
)

View file

@@ -29,8 +29,8 @@ providers:
provider_type: inline::basic
config: {}
datasetio:
- - provider_id: huggingface-0
- provider_type: remote::huggingface
+ - provider_id: localfs
+ provider_type: inline::localfs
config: {}
telemetry:
- provider_id: meta-reference
@@ -68,23 +68,6 @@ metadata_store:
models: []
shields: []
memory_banks: []
- datasets:
- - dataset_id: alpaca
- provider_id: huggingface-0
- url:
- uri: https://huggingface.co/datasets/tatsu-lab/alpaca
- metadata:
- path: tatsu-lab/alpaca
- name:
- split: train
- dataset_schema:
- instruction:
- type: string
- input:
- type: string
- output:
- type: string
- text:
- type: string
+ datasets: []
scoring_fns: []
eval_tasks: []