From 4320b0ebb2b834f237c074a4539d1b1268c15854 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Fri, 3 Jan 2025 08:43:24 -0800
Subject: [PATCH] [Post training] make validation steps configurable (#715)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## what does this PR do?

The current code hardcodes the number of validation steps to run (a leftover that was not changed after testing). In this PR, we make it configurable via the training config.

## test

On the client side, issue a post-training request with 20 validation steps; server-side logging shows that it runs 20 validation steps successfully.

[Screenshot: 2025-01-02 at 8:21:06 PM]

---
 llama_stack/apis/post_training/post_training.py        | 1 +
 .../torchtune/recipes/lora_finetuning_single_device.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 1c2d2d6e2..8e1edbe87 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -58,6 +58,7 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
+    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 1b6c508a7..a2ef1c5dd 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -137,6 +137,7 @@ class LoraFinetuningSingleDevice:
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+        self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -583,7 +584,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == 10:
+            if idx == self.max_validation_steps:
                 break
 
             torchtune_utils.batch_to_device(batch, self._device)
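
For reference, a minimal sketch of the pattern this patch introduces: the validation loop breaks once `idx` reaches the configured `max_validation_steps` instead of a hardcoded `10`. The trimmed-down `TrainingConfig` and the `run_validation` helper below are illustrative stand-ins, not the actual llama_stack classes (which carry additional fields such as `data_config` and `optimizer_config`).

```python
# Illustrative sketch only -- not the real llama_stack API. Field names
# mirror the TrainingConfig fields touched by this PR; run_validation is
# a hypothetical helper standing in for the recipe's validation loop.
from typing import Iterable

from pydantic import BaseModel


class TrainingConfig(BaseModel):
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    max_validation_steps: int  # new field introduced by this PR


def run_validation(config: TrainingConfig, validation_batches: Iterable) -> int:
    """Iterate validation batches, stopping at the configured cap."""
    steps = 0
    for idx, batch in enumerate(validation_batches):
        if idx == config.max_validation_steps:  # previously hardcoded as 10
            break
        # ... move batch to device and compute validation loss here ...
        steps += 1
    return steps


config = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    max_validation_steps=20,  # matches the 20-step request in the test above
)
assert run_validation(config, range(1000)) == 20
```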