This commit is contained in:
Botao Chen 2025-01-02 20:17:04 -08:00
parent e3f187fb83
commit f2a0468b2b
2 changed files with 3 additions and 1 deletion

View file

@@ -58,6 +58,7 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
+    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None

View file

@@ -137,6 +137,7 @@ class LoraFinetuningSingleDevice:
         self.global_step = 0
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+        self.max_validation_steps = training_config.max_validation_steps

         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -583,7 +584,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == 10:
+            if idx == self.max_validation_steps:
                 break
             torchtune_utils.batch_to_device(batch, self._device)