From bbe190a0857ed245b58d515cd9710aec01954caa Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Fri, 3 Jan 2025 16:27:32 -0800
Subject: [PATCH] refine

---
 .../apis/post_training/post_training.py            |  1 +
 .../post_training/torchtune/common/utils.py        | 32 +++++++++++--------
 .../recipes/lora_finetuning_single_device.py       | 11 ++-----
 3 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index d9dae9e5c..216157d15 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -68,6 +68,7 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
+    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index 0c2b663d3..b0c5aec42 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -52,9 +52,9 @@ class ModelConfig(BaseModel):
 
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
-    instruct: Dict[str, ParamType]
-    chat_sharegpt: Dict[str, ParamType]
-    chat_openai: Dict[str, ParamType]
+    instruct: List[Dict[str, ParamType]]
+    chat_sharegpt: List[Dict[str, ParamType]]
+    chat_openai: List[Dict[str, ParamType]]
 
 
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -96,16 +96,22 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.output.value: StringType(),
         },
     ],
-    instruct={
-        ColumnName.input.value: StringType(),
-        ColumnName.output.value: StringType(),
-    },
-    chat_sharegpt={
-        ColumnName.conversations.value: StringType(),
-    },
-    chat_openai={
-        ColumnName.messages.value: StringType(),
-    },
+    instruct=[
+        {
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        }
+    ],
+    chat_sharegpt=[
+        {
+            ColumnName.conversations.value: StringType(),
+        }
+    ],
+    chat_openai=[
+        {
+            ColumnName.messages.value: StringType(),
+        }
+    ],
 )
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 1bbae6d2e..a3649d5ae 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -126,8 +126,7 @@ class LoraFinetuningSingleDevice:
 
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
 
-        # self.seed = training.set_seed(seed=config.torch_seed)
-        self.seed = 42
+        self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
         self._data_format = training_config.data_config.data_format
@@ -140,6 +139,7 @@ class LoraFinetuningSingleDevice:
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+        self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -366,7 +366,6 @@ class LoraFinetuningSingleDevice:
             column_map=self._column_map,
         )
         data_transform = await utils.get_data_transform(self._data_format)
-        print("data_transform", data_transform.__name__)
         ds = SFTDataset(
             rows,
             message_transform=data_transform(
@@ -450,10 +449,6 @@ class LoraFinetuningSingleDevice:
 
     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
-        # print("tokens", batch["tokens"])
-        torch.save(batch["tokens"], "/home/markchen1015/new_alpaca_tokens.pth")
-        # print("labels", batch["labels"])
-        torch.save(batch["labels"], "/home/markchen1015/new_alpaca_labels.pth")
         labels = batch.pop("labels")
         # run model
         with self.activations_handling_ctx:
@@ -595,7 +590,7 @@ class LoraFinetuningSingleDevice:
 
             log.info("Starting validation...")
             pbar = tqdm(total=len(self._validation_dataloader))
            for idx, batch in enumerate(self._validation_dataloader):
-                if idx == 10:
+                if idx == self.max_validation_steps:
                     break
                 torchtune_utils.batch_to_device(batch, self._device)
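
Note (not part of the patch): the core behavioral change is that the hardcoded `if idx == 10` validation cap is replaced by the new `max_validation_steps` field on `TrainingConfig`, which the recipe reads in its constructor. A minimal, self-contained sketch of that behavior follows; the names `run_validation` and `validation_dataloader` are hypothetical illustrations and do not appear in the repository.

from typing import Any, Iterable


def run_validation(validation_dataloader: Iterable[Any], max_validation_steps: int) -> int:
    """Hypothetical helper mirroring the capped validation loop in the diff above."""
    processed = 0
    for idx, _batch in enumerate(validation_dataloader):
        if idx == max_validation_steps:  # replaces the previously hardcoded `if idx == 10`
            break
        # device placement and loss computation happen here in the real recipe
        processed += 1
    return processed


# Example: a dataloader yielding 100 batches is cut off after 10 validation steps.
assert run_validation(range(100), max_validation_steps=10) == 10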