forked from phoenix-oss/llama-stack-mirror
[Post training] make validation steps configurable (#715)
## what does this PR do? The current code hardcodes the number of validation steps to run (it was forgotten after testing). In this PR, we make it configurable via the training config. ## test On the client side, issue a post-training request with 20 validation steps; server-side logging shows that it runs 20 validation steps successfully <img width="1128" alt="Screenshot 2025-01-02 at 8 21 06 PM" src="https://github.com/user-attachments/assets/7a757516-c6ba-41d4-85c5-361a80ecf46e" />
This commit is contained in:
parent
f450a0fd32
commit
4320b0ebb2
2 changed files with 3 additions and 1 deletions
|
@ -58,6 +58,7 @@ class TrainingConfig(BaseModel):
|
||||||
n_epochs: int
|
n_epochs: int
|
||||||
max_steps_per_epoch: int
|
max_steps_per_epoch: int
|
||||||
gradient_accumulation_steps: int
|
gradient_accumulation_steps: int
|
||||||
|
max_validation_steps: int
|
||||||
data_config: DataConfig
|
data_config: DataConfig
|
||||||
optimizer_config: OptimizerConfig
|
optimizer_config: OptimizerConfig
|
||||||
efficiency_config: Optional[EfficiencyConfig] = None
|
efficiency_config: Optional[EfficiencyConfig] = None
|
||||||
|
|
|
@ -137,6 +137,7 @@ class LoraFinetuningSingleDevice:
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
|
|
||||||
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
|
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
|
||||||
|
self.max_validation_steps = training_config.max_validation_steps
|
||||||
|
|
||||||
self._clip_grad_norm = 1.0
|
self._clip_grad_norm = 1.0
|
||||||
self._enable_activation_checkpointing = (
|
self._enable_activation_checkpointing = (
|
||||||
|
@ -583,7 +584,7 @@ class LoraFinetuningSingleDevice:
|
||||||
log.info("Starting validation...")
|
log.info("Starting validation...")
|
||||||
pbar = tqdm(total=len(self._validation_dataloader))
|
pbar = tqdm(total=len(self._validation_dataloader))
|
||||||
for idx, batch in enumerate(self._validation_dataloader):
|
for idx, batch in enumerate(self._validation_dataloader):
|
||||||
if idx == 10:
|
if idx == self.max_validation_steps:
|
||||||
break
|
break
|
||||||
torchtune_utils.batch_to_device(batch, self._device)
|
torchtune_utils.batch_to_device(batch, self._device)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue