From 929e7c0e69cc7b0e15a6fda420a72d04a1d623f1 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Thu, 9 Jan 2025 16:57:48 -0800 Subject: [PATCH] refine --- .../apis/post_training/post_training.py | 2 +- .../inline/post_training/common/validator.py | 2 +- .../post_training/torchtune/common/utils.py | 2 +- .../torchtune/datasets/format_adapter.py | 1 + .../recipes/lora_finetuning_single_device.py | 4 ++-- .../experimental-post-training/run.yaml | 23 +++---------------- 6 files changed, 9 insertions(+), 25 deletions(-) diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 9bcaa24ac..5385a39b0 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -30,7 +30,7 @@ class OptimizerType(Enum): @json_schema_type class DatasetFormat(Enum): instruct = "instruct" - chat = "chat" + chat = "dialog" @json_schema_type diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py index 2a7f67fd5..eac6b3302 100644 --- a/llama_stack/providers/inline/post_training/common/validator.py +++ b/llama_stack/providers/inline/post_training/common/validator.py @@ -23,7 +23,7 @@ EXPECTED_DATASET_SCHEMA = { ColumnName.expected_answer.value: StringType(), } ], - "chat": [ + "dialog": [ { ColumnName.dialog.value: StringType(), } diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 198d2f28e..a838f8d72 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -49,7 +49,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = { DATA_FORMATS: Dict[str, Transform] = { "instruct": InputOutputToMessages, - "chat": ShareGPTToMessages, + "dialog": ShareGPTToMessages, } diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py index d36244464..b4dfbb3c1 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py @@ -30,6 +30,7 @@ def llama_stack_instruct_to_torchtune_instruct( input_message = input_messages[0] assert "content" in input_message, "content not found in input message" + input = input_message["content"] output = sample[ColumnName.expected_answer.value] return { diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 95e2bc220..8257da9e2 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -135,7 +135,7 @@ class LoraFinetuningSingleDevice: self._data_format = training_config.data_config.data_format self._shuffle = training_config.data_config.shuffle self._batch_size = training_config.data_config.batch_size - self._training_on_input = training_config.data_config.training_on_input + self._train_on_input = training_config.data_config.train_on_input # this is important for debugging purpose self.max_steps_per_epoch = training_config.max_steps_per_epoch @@ -367,7 +367,7 @@ class LoraFinetuningSingleDevice: data_transform = await utils.get_data_transform(self._data_format) ds = SFTDataset( rows, - message_transform=data_transform(train_on_input=self._training_on_input), + message_transform=data_transform(train_on_input=self._train_on_input), model_transform=tokenizer, dataset_type=self._data_format.value, ) diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml index a654c375e..308f03a2e 100644 --- a/llama_stack/templates/experimental-post-training/run.yaml +++ b/llama_stack/templates/experimental-post-training/run.yaml @@ -29,8 +29,8 @@ providers: provider_type: inline::basic config: {} datasetio: - - provider_id: huggingface-0 - provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs config: {} telemetry: - provider_id: meta-reference @@ -68,23 +68,6 @@ metadata_store: models: [] shields: [] memory_banks: [] -datasets: - - dataset_id: alpaca - provider_id: huggingface-0 - url: - uri: https://huggingface.co/datasets/tatsu-lab/alpaca - metadata: - path: tatsu-lab/alpaca - name: - split: train - dataset_schema: - instruction: - type: string - input: - type: string - output: - type: string - text: - type: string +datasets: [] scoring_fns: [] eval_tasks: []