diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 1c2d2d6e2..8e1edbe87 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -58,6 +58,7 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
+    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 1b6c508a7..a2ef1c5dd 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -137,6 +137,7 @@ class LoraFinetuningSingleDevice:
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+        self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -583,7 +584,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == 10:
+            if idx == self.max_validation_steps:
                 break
 
             torchtune_utils.batch_to_device(batch, self._device)
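
For context, a minimal usage sketch (not part of the diff) of how the new max_validation_steps field might be supplied when building a TrainingConfig. The build_training_config helper and the chosen values are hypothetical; DataConfig and OptimizerConfig instances are assumed to be constructed elsewhere, since their fields are not shown in this change.

    # Hypothetical sketch: populating the new TrainingConfig.max_validation_steps field.
    from llama_stack.apis.post_training.post_training import (
        DataConfig,
        OptimizerConfig,
        TrainingConfig,
    )

    def build_training_config(data_config: DataConfig, optimizer_config: OptimizerConfig) -> TrainingConfig:
        # max_validation_steps replaces the hard-coded `idx == 10` cutoff in the
        # torchtune LoRA recipe's validation loop; here it caps validation at 20 batches.
        return TrainingConfig(
            n_epochs=1,
            max_steps_per_epoch=100,
            gradient_accumulation_steps=1,
            max_validation_steps=20,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )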